Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9eb97340
Unverified
Commit
9eb97340
authored
Dec 04, 2024
by
Mingtao Gu
Committed by
GitHub
Dec 04, 2024
Browse files
Merge pull request #1 from ROCm/i4_update
extend support KPerBlock <= ScaleBlockK
parents
40054f53
1a324dfb
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
21 additions
and
30 deletions
+21
-30
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+1
-0
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
+4
-17
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp
...n/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp
+14
-11
No files found.
example/01_gemm/CMakeLists.txt
View file @
9eb97340
...
@@ -31,6 +31,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
...
@@ -31,6 +31,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
add_example_executable
(
example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_pk_i4_v3_b_scale gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_pk_i4_v3_b_scale gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
)
target_compile_options
(
example_gemm_xdl_fp16_pk_i4_v3_b_scale PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_fp8_v3
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_fp8_v3
)
add_example_executable
(
example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp
)
add_example_executable
(
example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16_v3
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16_v3
)
...
...
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
View file @
9eb97340
...
@@ -27,7 +27,7 @@ static constexpr bool PermuteB = false;
...
@@ -27,7 +27,7 @@ static constexpr bool PermuteB = false;
static
constexpr
ck
::
index_t
Scale_Block_N
=
1
;
static
constexpr
ck
::
index_t
Scale_Block_N
=
1
;
static
constexpr
ck
::
index_t
Scale_Block_K
=
128
;
static
constexpr
ck
::
index_t
Scale_Block_K
=
128
;
static
constexpr
ck
::
index_t
KPerBlock
=
128
;
static
constexpr
ck
::
index_t
KPerBlock
=
64
;
// clang-format off
// clang-format off
using
DeviceGemmV2Instance
=
using
DeviceGemmV2Instance
=
...
@@ -35,29 +35,16 @@ using DeviceGemmV2Instance =
...
@@ -35,29 +35,16 @@ using DeviceGemmV2Instance =
ALayout
,
BLayout
,
CLayout
,
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
BScaleDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
ADataType
,
BDataType
,
BScaleDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
#if 0
128, Scale_Block_N, Scale_Block_K,
16, 128,
KPerBlock, 8, 32,
16, 16,
1, 4,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
#else
256
,
Scale_Block_N
,
Scale_Block_K
,
256
,
Scale_Block_N
,
Scale_Block_K
,
128
,
128
,
128
,
128
,
KPerBlock
,
8
,
32
,
KPerBlock
,
8
,
32
,
32
,
32
,
32
,
32
,
2
,
2
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
2
,
128
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
32
,
32
,
0
,
2
,
32
,
32
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
#endif
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
,
CDataType
,
CDataType
,
false
,
PermuteB
>
;
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
,
CDataType
,
CDataType
,
false
,
PermuteB
>
;
// clang-format on
// clang-format on
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp
View file @
9eb97340
...
@@ -328,7 +328,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
...
@@ -328,7 +328,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_grid_buf
,
b_scale_thread_desc
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
make_tuple
(
n0
,
I0
,
I0
),
b_scale_thread_buf
);
b_scale_thread_buf
);
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
...
@@ -455,7 +455,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
...
@@ -455,7 +455,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_grid_buf
,
b_scale_thread_desc
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
make_tuple
(
n0
,
I0
,
I0
),
b_scale_thread_buf
);
b_scale_thread_buf
);
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_thread_copy
.
MoveSrcSliceWindow
(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp
View file @
9eb97340
...
@@ -713,7 +713,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -713,7 +713,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
{
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
if
constexpr
(
ABlockLdsExtraM
)
if
constexpr
(
ABlockLdsExtraM
||
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
{
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
AK0Number
,
Number
<
MPerBlock
>
{},
AK1Number
),
make_tuple
(
AK0Number
,
Number
<
MPerBlock
>
{},
AK1Number
),
...
@@ -849,7 +849,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -849,7 +849,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
{
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, dst of blockwise copy
if
constexpr
(
BBlockLdsExtraN
)
if
constexpr
(
BBlockLdsExtraN
||
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
{
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0Number
,
Number
<
NPerBlock
>
{},
BK1Number
),
make_tuple
(
BK0Number
,
Number
<
NPerBlock
>
{},
BK1Number
),
...
@@ -1303,8 +1303,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1303,8 +1303,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
// B Scale grid and buffer
// B Scale grid and buffer
const
auto
b_scale_grid_desc_bn_ak
=
make_naive_tensor_descriptor
(
const
auto
b_scale_grid_desc_bn_ak
=
make_naive_tensor_descriptor
(
make_tuple
(
math
::
integer_divide_ceil
(
problem
.
N
,
ScaleBlockN
),
make_tuple
(
math
::
integer_divide_ceil
(
problem
.
N
,
ScaleBlockN
),
math
::
integer_divide_ceil
(
problem
.
K
,
ScaleBlockK
)),
math
::
integer_divide_ceil
(
problem
.
K
,
ScaleBlockK
),
make_tuple
(
math
::
integer_divide_ceil
(
problem
.
K
,
ScaleBlockK
),
1
));
ScaleBlockK
),
make_tuple
(
math
::
integer_divide_ceil
(
problem
.
K
,
ScaleBlockK
),
1
,
0
));
const
auto
b_scale_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
b_scale_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_scale_grid
,
b_scale_grid_desc_bn_ak
.
GetElementSpaceSize
());
p_b_scale_grid
,
b_scale_grid_desc_bn_ak
.
GetElementSpaceSize
());
...
@@ -1435,11 +1438,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1435,11 +1438,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
KPerBlock
);
KPerBlock
);
// b scale
// b scale
static_assert
(
KPerBlock
<=
ScaleBlockK
);
const
index_t
ScaleSliceSizeN
=
NXdlPerWave
;
const
index_t
ScaleSliceSizeN
=
NXdlPerWave
;
const
index_t
ScaleSliceSizeK
=
1
;
const
index_t
ScaleSliceSizeK
=
1
;
constexpr
auto
b_scale_thread_desc
=
make_naive_tensor_descriptor_packed
(
constexpr
auto
b_scale_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
ScaleSliceSizeN
>
{},
Number
<
ScaleSliceSizeK
>
{}));
make_tuple
(
Number
<
ScaleSliceSizeN
>
{},
Number
<
ScaleSliceSizeK
>
{}
,
Number
<
1
>
{}
));
constexpr
index_t
NWaves
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
constexpr
index_t
NWaves
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
...
@@ -1451,17 +1455,18 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1451,17 +1455,18 @@ struct GridwiseGemm_xdl_cshuffle_v3
BScaleType
,
BScaleType
,
decltype
(
b_scale_grid_desc_bn_ak
),
decltype
(
b_scale_grid_desc_bn_ak
),
decltype
(
b_scale_thread_desc
),
decltype
(
b_scale_thread_desc
),
Sequence
<
1
,
ScaleSliceSizeK
>
,
Sequence
<
1
,
ScaleSliceSizeK
,
1
>
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
1
,
1
,
1
,
1
,
1
,
1
,
false
>
(
false
>
(
b_scale_grid_desc_bn_ak
,
b_scale_grid_desc_bn_ak
,
make_multi_index
(
block_n_id
*
NPerBlock
/
ScaleBlockN
+
b_thread_offset
,
0
));
make_multi_index
(
block_n_id
*
NPerBlock
/
ScaleBlockN
+
b_thread_offset
,
0
,
0
));
constexpr
auto
b_scale_thread_slice_copy_step
=
constexpr
auto
b_scale_thread_slice_copy_step
=
make_tuple
(
make_multi_index
(
NWaves
*
NPerXdl
,
0
),
make_multi_index
(
-
NPerBlock
,
1
));
make_tuple
(
make_multi_index
(
NWaves
*
NPerXdl
,
0
,
0
),
make_multi_index
(
-
NPerBlock
,
KPerBlock
/
ScaleBlockK
,
KPerBlock
%
ScaleBlockK
));
const
index_t
num_k_block_per_scale
=
ScaleBlockK
/
KPerBlock
;
const
index_t
num_k_block_per_scale
=
ScaleBlockK
/
KPerBlock
;
...
@@ -1478,13 +1483,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1478,13 +1483,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
b_block_buf
,
b_block_buf
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
c_thread_buf
,
c_thread_buf
,
b_scale_grid_desc_bn_ak
,
b_scale_grid_desc_bn_ak
,
b_scale_thread_desc
,
b_scale_thread_desc
,
b_scale_thread_copy
,
b_scale_thread_copy
,
b_scale_grid_buf
,
b_scale_grid_buf
,
b_scale_thread_slice_copy_step
,
b_scale_thread_slice_copy_step
,
num_k_block_main_loop
,
num_k_block_main_loop
,
num_k_block_per_scale
);
num_k_block_per_scale
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment