Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8ea1c974
Commit
8ea1c974
authored
Jan 09, 2025
by
mtgu0705
Browse files
change the custom int4 to uint8 for verification
parent
d201acc0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
26 deletions
+27
-26
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
+3
-3
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+5
-5
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+19
-18
No files found.
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
View file @
8ea1c974
...
...
@@ -208,7 +208,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
int
i4x2
=
b_k_n_permute
(
j
+
k
*
2
,
i
)
.
data
;
int
i4x2
=
b_k_n_permute
(
j
+
k
*
2
,
i
);
input
[
k
*
2
+
0
]
=
(
i4x2
>>
4
)
&
0xf
;
input
[
k
*
2
+
1
]
=
(
i4x2
>>
0
)
&
0xf
;
}
...
...
@@ -303,9 +303,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck
::
pk_i4_t
i4x2
=
b_k_n
(
k
,
n
);
int8_t
i4
=
0
;
if
(
k
%
2
==
1
)
i4
=
(
i4x2
.
data
>>
0
)
&
0xf
;
i4
=
(
i4x2
>>
0
)
&
0xf
;
else
i4
=
(
i4x2
.
data
>>
4
)
&
0xf
;
i4
=
(
i4x2
>>
4
)
&
0xf
;
i4
=
i4
-
8
;
v_b
=
ck
::
type_convert
<
float
>
(
i4
);
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
8ea1c974
...
...
@@ -324,11 +324,11 @@ struct PassThrough final : public UnaryOpBase
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
pk_i4_t
,
pk_i4_t
>
(
pk_i4_t
&
y
,
const
pk_i4_t
&
x
)
const
{
y
=
x
;
}
//
template <>
//
__host__ __device__ void operator()<pk_i4_t, pk_i4_t>(pk_i4_t& y, const pk_i4_t& x) const
//
{
//
y = x;
//
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
double
>
(
float
&
y
,
const
double
&
x
)
const
...
...
include/ck/utility/data_type.hpp
View file @
8ea1c974
...
...
@@ -11,15 +11,16 @@ namespace ck {
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
using
int4_t
=
_BitInt
(
4
);
using
pk_i4_t
=
uint8_t
;
// custom data type - pack int4 data
struct
pk_i4_t
{
using
type
=
uint8_t
;
type
data
;
__host__
__device__
constexpr
pk_i4_t
()
:
data
{
type
{}}
{}
__host__
__device__
constexpr
pk_i4_t
(
type
init
)
:
data
{
init
}
{}
};
//
struct pk_i4_t
//
{
//
using type = uint8_t;
//
type data;
//
__host__ __device__ constexpr pk_i4_t() : data{type{}} {}
//
__host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
//
};
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
{
...
...
@@ -174,12 +175,12 @@ struct scalar_type<int4_t>
};
#endif
template
<
>
struct
scalar_type
<
pk_i4_t
>
{
using
type
=
pk_i4_t
;
static
constexpr
index_t
vector_size
=
1
;
};
//
template <>
//
struct scalar_type<pk_i4_t>
//
{
//
using type = pk_i4_t;
//
static constexpr index_t vector_size = 1;
//
};
template
<
>
struct
scalar_type
<
f8_fnuz_t
>
...
...
@@ -1060,11 +1061,11 @@ struct nnvb_data_t_selector<bf8_ocp_t>
using
type
=
bf8_ocp_t
::
data_type
;
};
template
<
>
struct
nnvb_data_t_selector
<
pk_i4_t
>
{
using
type
=
pk_i4_t
::
type
;
};
//
template <>
//
struct nnvb_data_t_selector<pk_i4_t>
//
{
//
using type = pk_i4_t::type;
//
};
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_base
<
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment