Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a7f24bfc
Commit
a7f24bfc
authored
Jan 02, 2025
by
mtgu0705
Browse files
Add unit test for int4 weight only
parent
8b83b087
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
80 deletions
+13
-80
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+1
-0
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
+12
-80
No files found.
example/01_gemm/CMakeLists.txt
View file @
a7f24bfc
...
...
@@ -30,6 +30,7 @@ add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8_v3
)
add_example_executable
(
example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_pk_i4_v3_b_scale gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
)
add_example_executable
(
example_gemm_xdl_bf16_pk_i4_v3 gemm_xdl_bf16_pk_i4_v3.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_fp8_v3
)
add_example_executable
(
example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp
)
...
...
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
View file @
a7f24bfc
...
...
@@ -22,6 +22,7 @@ using CElementOp = PassThrough;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
bool
PermuteA
=
false
;
static
constexpr
bool
PermuteB
=
true
;
static
constexpr
ck
::
index_t
Scale_Block_N
=
1
;
...
...
@@ -45,7 +46,7 @@ using DeviceGemmV2Instance =
S
<
2
,
128
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
32
,
32
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
,
CDataType
,
CDataType
,
false
,
PermuteB
>
;
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
,
CDataType
,
CDataType
,
PermuteA
,
PermuteB
>
;
// clang-format on
...
...
@@ -82,21 +83,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
};
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size
_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
[](
std
::
size_t
row
,
std
::
size_t
col
,
ck
::
index
_t
stride
,
auto
layout
)
{
if
(
stride
==
-
1
)
{
// give a chance if stride is
zero
, return a default packed stride
// give a chance if stride is
-1
, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
col
;
return
static_cast
<
std
::
size_t
>
(
col
)
;
}
else
{
return
row
;
return
static_cast
<
std
::
size_t
>
(
row
)
;
}
}
else
return
stride
;
return
static_cast
<
std
::
size_t
>
(
stride
)
;
};
ck
::
index_t
Scale_Stride_BN
=
(
K
+
Scale_Block_K
-
1
)
/
Scale_Block_K
;
...
...
@@ -164,11 +165,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
DeviceMem
b1_scale_device_buf
(
sizeof
(
BScaleDataType
)
*
b1_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
printf
(
"b_k_n element space size: %zu, b_k_n device size: %lu, BDataType size: %lu
\n
"
,
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
(),
sizeof
(
BDataType
)
*
b_k_n_permute
.
mDesc
.
GetElementSpaceSize
(),
sizeof
(
BDataType
));
// weight permute
if
constexpr
(
PermuteB
)
{
...
...
@@ -199,7 +195,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
// vector pk_i4x4 permute
#if 1
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
for
(
int
j
=
0
;
j
<
K
;
j
+=
8
)
...
...
@@ -208,7 +203,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
int
i4x2
=
b_k_n_permute
(
j
+
k
*
2
,
i
);
int
i4x2
=
b_k_n_permute
(
j
+
k
*
2
,
i
)
.
data
;
input
[
k
*
2
+
0
]
=
(
i4x2
>>
4
)
&
0xf
;
input
[
k
*
2
+
1
]
=
(
i4x2
>>
0
)
&
0xf
;
}
...
...
@@ -247,7 +242,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
}
}
#endif
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n_permute
.
mData
.
data
());
...
...
@@ -287,9 +281,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
return
true
;
}
std
::
size_t
workspace_size
=
gemm
.
GetWorkSpaceSize
(
&
argument
);
printf
(
"workspace_size: %zu
\n
"
,
workspace_size
);
bool
pass
=
true
;
if
(
config
.
do_verification
)
{
...
...
@@ -300,12 +291,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
ck
::
pk_i4_t
i4x2
=
b_k_n
(
k
,
n
);
ck
::
pk_i4_t
i4x2
=
b_k_n
(
k
,
n
)
.
data
;
int8_t
i4
=
0
;
if
(
k
%
2
==
1
)
i4
=
(
i4x2
>>
0
)
&
0xf
;
i4
=
(
i4x2
.
data
>>
0
)
&
0xf
;
else
i4
=
(
i4x2
>>
4
)
&
0xf
;
i4
=
(
i4x2
.
data
>>
4
)
&
0xf
;
i4
=
i4
-
8
;
v_b
=
ck
::
type_convert
<
float
>
(
i4
);
...
...
@@ -331,65 +322,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
"Error: Incorrect results!"
,
get_rtol
<
CDataType
>
(),
get_atol
<
CDataType
>
());
#if 0
std::cout << "a_m_k: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < K; j++)
{
std::cout << ck::type_convert<float>(a_m_k(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "b_k_n: " << std::endl;
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
ck::pk_i4_t i4x2 = b_k_n(j, i);
int8_t i4 = 0;
if( j % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
std::cout << ck::type_convert<float>(i4) << ",";
}
std::cout << std::endl;
}
std::cout<<"scale_b1_k_n: "<<std::endl;
for(int i = 0; i < N; i++)
{
for(int j = 0; j < (K + Scale_Block_K - 1) / Scale_Block_K; j++)
{
std::cout << ck::type_convert<float>(b1_k_n(j,i)) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_device_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_host_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_host_result(i, j)) << ",";
}
std::cout << std::endl;
}
#endif
}
if
(
config
.
time_kernel
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment