gaoqiong / composable_kernel_ROCM / Commits / 1fcd3329

Commit 1fcd3329, authored Dec 23, 2024 by mtgu0705 (parent e5bc56a4)

    Enable multiply_multiply for Scale_Block_M = 1 for deepseek

Showing 4 changed files with 56 additions and 26 deletions.
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp                     +16 -3
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp              +26 -14
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp   +4 -4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp           +10 -5
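This commit enables the multiply_multiply (A-scale times B-scale) FP8 GEMM path for DeepSeek-style quantization, where activations carry one scale per token per 128-wide K block (Scale_Block_M = 1, Scale_Block_K = 128) and weights carry 128x128 scale blocks. Concretely: the example flips Scale_Block_M from 128 to 1, the device-level argument check that rejected ScaleBlockM smaller than MPerBlock is relaxed, and the blockwise/gridwise pipeline loads and applies one A scale per MRepeat sub-tile instead of one per thread block. As a reference for the computation being targeted, here is a minimal host-side sketch (assumed layouts and names, not code from this commit):

    #include <vector>

    // C = (A .* a_scale) x (B .* b_scale): per-token A scales (1 x 128 blocks)
    // and 128 x 128 B scale blocks, all row-major.
    void reference_gemm_ab_scale(const std::vector<float>& a,       // M x K
                                 const std::vector<float>& b,       // K x N
                                 const std::vector<float>& a_scale, // M x (K/128)
                                 const std::vector<float>& b_scale, // (K/128) x (N/128)
                                 std::vector<float>& c,             // M x N
                                 int M, int N, int K)
    {
        const int KB = 128, NB = 128; // Scale_Block_K, Scale_Block_N
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += a[m * K + k] * a_scale[m * (K / KB) + k / KB] *
                           b[k * N + n] * b_scale[(k / KB) * (N / NB) + n / NB];
                c[m * N + n] = acc;
            }
    }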
--- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
+++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
@@ -26,6 +26,7 @@ using S = ck::Sequence<Is...>;
 using BF16 = ck::bhalf_t;
 using FP8  = ck::f8_t;
 using F16  = ck::half_t;
 using F32  = float;
 using Row  = ck::tensor_layout::gemm::RowMajor;
@@ -55,7 +56,7 @@ using CDEElementOp = PassThrough;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;

-static constexpr ck::index_t Scale_Block_M = 128;
+static constexpr ck::index_t Scale_Block_M = 1;
 static constexpr ck::index_t Scale_Block_N = 128;
 static constexpr ck::index_t Scale_Block_K = 128;
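With Scale_Block_M = 1, the A-scale tensor effectively has one entry per output row: a1_m_k becomes M x (K / Scale_Block_K) rather than (M / 128) x (K / 128). Scale_Block_N and Scale_Block_K stay at 128, matching DeepSeek's 1x128 per-token activation scales and 128x128 weight-scale blocks.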
@@ -67,8 +68,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_
     256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
     128, 128, 128,
     16, 16,
     16, 16,
     4, 4,
     32, 32,
     2, 2,
     S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
     S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
     1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
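Reading this instance against the DeviceGemmMultiD_ABScale template's parameter order (an inference from CK conventions, not stated in the diff): 256 is the thread-block size, the three Scale_Block_* constants follow, and 128/128/128 is the per-block M/N/K tile; the remaining scalars and S<...> sequences configure the XDL instruction shape and the A/B/C transfer-thread clusters.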
@@ -187,6 +188,18 @@ int main(int argc, char* argv[])
         a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
         b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
         break;
+    case 5:
+        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
+        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
+        a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
+        b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
+        break;
+    case 6:
+        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
+        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
+        a1_m_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
+        b1_k_n.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
+        break;
     default:
         a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
         b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
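In CK's test utilities, roughly: GeneratorTensor_1 fills a constant (1 by default), GeneratorTensor_2 draws random integers in [min, max), and GeneratorTensor_3 draws random reals in the given range. The two new init cases therefore pin one operand's scales to 1 while randomizing the other's: case 5 fixes the A scales and case 6 the B scales, which isolates each scale path when validating the Scale_Block_M = 1 configuration.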
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
@@ -338,11 +338,18 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

-        a_scale_thread_copy.Run(a_scale_grid_desc,
-                                a_scale_grid_buf,
-                                a_scale_thread_desc,
-                                make_tuple(I0, I0),
-                                a_scale_thread_buf);
+        static_for<0, MRepeat, 1>{}([&](auto m0) {
+            a_scale_thread_copy.Run(a_scale_grid_desc,
+                                    a_scale_grid_buf,
+                                    a_scale_thread_desc,
+                                    make_tuple(m0, I0),
+                                    a_scale_thread_buf);
+            a_scale_thread_copy.MoveSrcSliceWindow(
+                a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
+        });
+        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+                                               a_scale_thread_copy_step.At(Number<1>{}));

         b_scale_thread_copy.Run(b_scale_grid_desc,
                                 b_scale_grid_buf,
@@ -350,7 +357,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                                 make_tuple(I0, I0),
                                 b_scale_thread_buf);

-        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
         b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);

         // Local prefill 1
         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
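Previously each K block loaded a single A scale into slot (I0, I0) and advanced the scale window once. The pipeline now unrolls over MRepeat, loading one scale per m0 sub-tile and stepping the window by copy-step component 0 between loads; the final MoveSrcSliceWindow with component 1 rewinds M and advances one K column (the step tuple itself is defined in the gridwise file below). CK's static_for<begin, end, step> unrolls the lambda at compile time, so m0 is usable as a compile-time index in make_tuple(m0, I0). A minimal standalone analogue of the idiom (illustrative, not CK's implementation):

    #include <cstddef>
    #include <type_traits>
    #include <utility>

    // unrolled_for<N>(f) calls f(integral_constant<size_t, i>) for i = 0..N-1,
    // expanded at compile time -- the same shape as CK's
    // static_for<0, N, 1>{}([&](auto i) { ... });
    template <std::size_t... Is, typename F>
    void unrolled_for_impl(std::index_sequence<Is...>, F&& f)
    {
        (f(std::integral_constant<std::size_t, Is>{}), ...);
    }

    template <std::size_t N, typename F>
    void unrolled_for(F&& f)
    {
        unrolled_for_impl(std::make_index_sequence<N>{}, std::forward<F>(f));
    }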
@@ -437,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                             c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                         c_thread_buf(Number<c_offset>{}) +=
                             c_thread_buf_per_scale[Number<t>{}] *
-                            type_convert<AccDataType>(a_scale_thread_buf[I0]) *
+                            type_convert<AccDataType>(a_scale_thread_buf[m0]) *
                             type_convert<AccDataType>(b_scale_thread_buf[I0]);
                     });
                 });
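The accumulator epilogue changes in one place: the A scale is now indexed by m0, so each MRepeat sub-tile is scaled by its own row scale instead of one block-wide scale (b_scale_thread_buf[I0] is unchanged because Scale_Block_N still covers the whole block). Schematically, with assumed names and flattened buffers rather than CK code:

    // Each m0 sub-tile applies its own A scale; one B scale covers the block.
    void apply_scales(float* c, const float* partial, const float* a_scale,
                      float b_scale, int MRepeat, int NRepeat, int AccNum)
    {
        for(int m0 = 0; m0 < MRepeat; ++m0)
            for(int n0 = 0; n0 < NRepeat; ++n0)
                for(int t = 0; t < AccNum; ++t)
                    c[(m0 * NRepeat + n0) * AccNum + t] +=
                        partial[(m0 * NRepeat + n0) * AccNum + t] *
                        a_scale[m0] * b_scale; // was a_scale[0] before this commit
    }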
@@ -462,11 +468,18 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                                     b_thread_buf);
                         });
                     });

-        a_scale_thread_copy.Run(a_scale_grid_desc,
-                                a_scale_grid_buf,
-                                a_scale_thread_desc,
-                                make_tuple(I0, I0),
-                                a_scale_thread_buf);
+        static_for<0, MRepeat, 1>{}([&](auto m0) {
+            a_scale_thread_copy.Run(a_scale_grid_desc,
+                                    a_scale_grid_buf,
+                                    a_scale_thread_desc,
+                                    make_tuple(m0, I0),
+                                    a_scale_thread_buf);
+            a_scale_thread_copy.MoveSrcSliceWindow(
+                a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
+        });
+        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+                                               a_scale_thread_copy_step.At(Number<1>{}));

         b_scale_thread_copy.Run(b_scale_grid_desc,
                                 b_scale_grid_buf,
@@ -474,7 +487,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                                 make_tuple(I0, I0),
                                 b_scale_thread_buf);

-        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
         b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);

         HotLoopScheduler();
         __builtin_amdgcn_sched_barrier(0);
@@ -514,7 +526,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                             c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                         c_thread_buf(Number<c_offset>{}) +=
                             c_thread_buf_per_scale[Number<t>{}] *
-                            type_convert<AccDataType>(a_scale_thread_buf[I0]) *
+                            type_convert<AccDataType>(a_scale_thread_buf[m0]) *
                             type_convert<AccDataType>(b_scale_thread_buf[I0]);
                     });
                 });
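The @@ -462, @@ -474, and @@ -514 hunks repeat the same three edits (per-m0 scale loads, dropped single-step window move, m0-indexed accumulation) for the second occurrence of the scale-load/accumulate sequence in this pipeline, keeping both halves of the software-pipelined hot loop consistent.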
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp
@@ -363,10 +363,10 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
             return false;
         }

-        if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock)
-        {
-            return false;
-        }
+        // if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock)
+        // {
+        //     return false;
+        // }

         if((arg.K % AK1 != 0 || arg.K % BK1 != 0) &&
            !(GemmSpec == GemmSpecialization::MKPadding ||
              GemmSpec == GemmSpecialization::NKPadding ||
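This is the guard that previously rejected Scale_Block_M = 1: IsSupportedArgument required ScaleBlockM to be a whole multiple of MPerBlock (and ScaleBlockK == KPerBlock). Commenting it out lets the per-token configuration through; note that no replacement check is added here, so scale-block sizes are now validated only implicitly by the kernel's indexing.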
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp
@@ -1357,7 +1357,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);

-        const index_t ScaleSliceSizeM = 1;
+        const index_t ScaleSliceSizeM = MXdlPerWave;
         const index_t ScaleSliceSizeN = 1;
         const index_t ScaleSliceSizeK = 1;
@@ -1365,20 +1365,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
             make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));

         constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
+            make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));

+        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
+        auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
+                               (get_thread_local_1d_id() / 64) % MWaves * MPerXdl;
+
         auto a_scale_thread_copy =
             ThreadwiseTensorSliceTransfer_v2<AScaleType,
                                              AScaleType,
                                              decltype(a_scale_grid_desc_am_ak),
                                              decltype(a_scale_thread_desc),
-                                             Sequence<ScaleSliceSizeM, ScaleSliceSizeK>,
+                                             Sequence<1, ScaleSliceSizeK>,
                                              Sequence<0, 1>,
                                              1,
                                              1,
                                              1,
                                              false>(
-                a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM, 0));
+                a_scale_grid_desc_am_ak,
+                make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0));

         auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<BScaleType,
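Since each thread now fetches the scales for its own output rows rather than a shared block scale, the copy origin gains a per-thread offset: the lane's row within one MPerXdl tile plus the wave's row band within the block. As a plain-C++ restatement (a hypothetical helper, assuming the 64-lane wavefront the "/ 64" implies):

    // Row of the A-scale grid a given thread starts from, relative to the
    // thread block's first row (mirrors a_thread_offset above).
    int a_scale_row_offset(int tid, int MPerXdl, int MWaves)
    {
        int row_in_xdl_tile = tid % MPerXdl;       // lane's row inside one XDL tile
        int wave_m          = (tid / 64) % MWaves; // which wave along M in the block
        return row_in_xdl_tile + wave_m * MPerXdl;
    }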
@@ -1393,7 +1397,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
                                              false>(
             b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

-        constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
+        constexpr auto a_scale_thread_slice_copy_step = make_tuple(
+            make_multi_index(MWaves * MPerXdl, 0), make_multi_index(-MPerBlock, 1));
         constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1);

         const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
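a_scale_thread_slice_copy_step is now a tuple of two steps rather than a single (0, 1) step, matching the At(Number<0>{}) / At(Number<1>{}) usage in the blockwise pipeline: component 0 hops MWaves * MPerXdl rows to the next m0 sub-tile, and component 1 rewinds MPerBlock rows while advancing one scale column in K. Assuming the pipeline's MRepeat equals MXdlPerWave here, MRepeat * MWaves * MPerXdl == MPerBlock, so the rewind lands the window back on the thread's first row. A host-side sketch of the walk (assumed names, not CK code):

    // Trace of the A-scale window over one tile's K blocks.
    void a_scale_window_walk(int row0, int k_blocks, int MRepeat, int MWaves, int MPerXdl)
    {
        int row = row0, k = 0;
        for(int kb = 0; kb < k_blocks; ++kb)
        {
            for(int m0 = 0; m0 < MRepeat; ++m0)
            {
                // Run(...) loads the scale at (row, k) into thread slot (m0, 0)
                row += MWaves * MPerXdl; // step At(Number<0>{})
            }
            row -= MRepeat * MWaves * MPerXdl; // step At(Number<1>{}): -MPerBlock ...
            k += 1;                            // ... and +1 scale column in K
        }
    }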