gaoqiong / composable_kernel_ROCM / Commits / 988478d4

Commit 988478d4, authored Dec 26, 2024 by chenjun

    edit fp8 ab scale for Scale_Block_M=1

parent f728087c
Showing 3 changed files with 119 additions and 91 deletions (+119 -91):
  example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp              +13 -24
  include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp       +43 -11
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp    +63 -56
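For orientation, the example exercises CK's block-scaled FP8 GEMM: A0/B0 hold the FP8 operands and A1/B1 hold one float scale per (Scale_Block_M x Scale_Block_K) and (Scale_Block_N x Scale_Block_K) tile, so "Scale_Block_M=1" in the commit title means one A scale per output row per K block. Below is a minimal host-side sketch of that contraction, assuming the usual CK ab-scale convention; plain float stands in for f8_t, the sizes are assumed to divide evenly, and every name outside the diff is illustrative only.

    #include <cstddef>
    #include <vector>

    // Reference dequantized GEMM for block-wise AB scaling (illustrative sketch).
    // a0 is MxK row-major, b0 is stored as NxK (column-major KxN), a1/b1 hold one
    // float scale per scale block. With Scale_Block_M == 1 the A scale degenerates
    // to one scale per row per K block.
    void reference_ab_scale_gemm(const std::vector<float>& a0, const std::vector<float>& a1,
                                 const std::vector<float>& b0, const std::vector<float>& b1,
                                 std::vector<float>& e,
                                 std::size_t M, std::size_t N, std::size_t K,
                                 std::size_t ScaleBlockM, std::size_t ScaleBlockN, std::size_t ScaleBlockK)
    {
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(std::size_t k = 0; k < K; ++k)
                {
                    const float a = a0[m * K + k] * a1[(m / ScaleBlockM) * (K / ScaleBlockK) + k / ScaleBlockK];
                    const float b = b0[n * K + k] * b1[(n / ScaleBlockN) * (K / ScaleBlockK) + k / ScaleBlockK];
                    acc += a * b;
                }
                e[m * N + n] = acc;
            }
    }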
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp

@@ -26,7 +26,6 @@ using S = ck::Sequence<Is...>;
 using BF16 = ck::bhalf_t;
 using FP8  = ck::f8_t;
-using F16  = ck::half_t;
 using F32  = float;
 using Row  = ck::tensor_layout::gemm::RowMajor;
@@ -68,11 +67,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_
         256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
         128, 128,
         128, 16, 16,
-        32, 32,
-        2, 2,
+        16, 16,
+        4, 4,
         S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
         S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
-        1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
+        1, 2, S<1, 32, 1, 8>, S<8>,
         ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
 // clang-format on
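Reading the positional arguments with the usual CK device-op ordering (an assumption; the diff only shows the raw numbers), the per-block tile stays at 128x128x128 while the XDL instruction tile drops from 32x32 to 16x16 and each wave now repeats 4x4 instead of 2x2, which keeps the wave layout of the 256-thread block intact. A small sketch of that arithmetic:

    #include <cstdio>

    // Wave-level tiling implied by the new instance, assuming the usual CK parameter
    // order (..., MPerBlock=128, NPerBlock=128, KPerBlock=128, AK1=16, BK1=16,
    // MPerXDL=16, NPerXDL=16, MXdlPerWave=4, NXdlPerWave=4, ...). Illustrative only.
    int main()
    {
        constexpr int MPerBlock = 128, NPerBlock = 128;
        constexpr int MPerXdl = 16, NPerXdl = 16;       // was 32 x 32
        constexpr int MXdlPerWave = 4, NXdlPerWave = 4; // was 2 x 2

        constexpr int MWaves = MPerBlock / (MXdlPerWave * MPerXdl); // 2
        constexpr int NWaves = NPerBlock / (NXdlPerWave * NPerXdl); // 2
        static_assert(MWaves * NWaves * 64 == 256, "4 waves of 64 lanes = BlockSize 256");

        std::printf("MWaves=%d NWaves=%d\n", MWaves, NWaves);
        return 0;
    }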
@@ -83,9 +82,9 @@ int main(int argc, char* argv[])
     bool time_kernel     = false;

     // GEMM shape
-    ck::index_t M = 3840;
-    ck::index_t N = 4096;
-    ck::index_t K = 4096;
+    ck::index_t M = 128;
+    ck::index_t N = 1024;
+    ck::index_t K = 1024;

     ck::index_t StrideA = K;
     ck::index_t StrideB = K;
@@ -101,7 +100,7 @@ int main(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 10)
+    else if(argc == 7)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
@@ -111,9 +110,9 @@ int main(int argc, char* argv[])
         N = std::stoi(argv[5]);
         K = std::stoi(argv[6]);

-        StrideA = std::stoi(argv[7]);
-        StrideB = std::stoi(argv[8]);
-        StrideE = std::stoi(argv[9]);
+        StrideA = K;
+        StrideB = K;
+        StrideE = N;
     }
     else
     {
@@ -185,20 +184,10 @@ int main(int argc, char* argv[])
     case 4:
         a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
         b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
-        // a1_m_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
         a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
-        // b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
-        b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
-        break;
-    case 5:
-        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
-        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
-        a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
-        b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
-        break;
-    case 6:
-        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
-        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
-        a1_m_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
-        b1_k_n.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
+        b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
         break;
     default:
         a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp

@@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                                                 NPerXDL,
                                                 MRepeat,
                                                 NRepeat,
-                                                KPack>
+                                                KPack,
+                                                true>
 {
     using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -117,7 +118,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                                                      NPerXDL,
                                                      MRepeat,
                                                      NRepeat,
-                                                     KPack>;
+                                                     KPack,
+                                                     true>;

     using Base::I0;
     using Base::KRepeat;
     using Base::xdlops_gemm;
@@ -338,18 +340,32 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

+           // a_scale_thread_copy.Run(a_scale_grid_desc,
+           //                         a_scale_grid_buf,
+           //                         a_scale_thread_desc,
+           //                         make_tuple(I0, I0),
+           //                         a_scale_thread_buf);
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_scale_thread_copy.Run(a_scale_grid_desc,
                                        a_scale_grid_buf,
                                        a_scale_thread_desc,
                                        make_tuple(m0, I0),
                                        a_scale_thread_buf);

                a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                       a_scale_thread_copy_step.At(Number<0>{}));
            });
-           a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
-                                                  a_scale_thread_copy_step.At(Number<1>{}));
+
+           if(num_loop_per_scale == 1)
+           {
+               a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+                                                      a_scale_thread_copy_step.At(Number<2>{}));
+           }
+           else
+           {
+               a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+                                                      a_scale_thread_copy_step.At(Number<1>{}));
+           }

            b_scale_thread_copy.Run(b_scale_grid_desc,
                                    b_scale_grid_buf,
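A sketch of what the new a_scale window walk does per main-loop iteration, assuming the three-step tuple introduced in the gridwise header further down in this commit (step 0 = advance one MRepeat row group, step 1 = rewind M only, step 2 = rewind M and move to the next K scale) and example-sized constants; this models the coordinates only, not the device copy:

    #include <cstdio>

    // Coordinate walk of a_scale_thread_copy for one main-loop iteration.
    // Assumed step tuple: At(0) = ( MWaves*MPerXdl, 0), At(1) = (-MPerBlock, 0),
    // At(2) = (-MPerBlock, 1). Illustrative host-side model only.
    int main()
    {
        int m = 0, k = 0; // window origin in the (scale-row, scale-K) grid
        const int MRepeat = 4, MWaves = 2, MPerXdl = 16, MPerBlock = 128;
        const int num_loop_per_scale = 1; // assumes ScaleBlockK == KPerBlock

        for(int m0 = 0; m0 < MRepeat; ++m0)
        {
            std::printf("load a_scale at (m=%d, k=%d)\n", m, k);
            m += MWaves * MPerXdl; // step At(0)
        }
        if(num_loop_per_scale == 1) { m -= MPerBlock; k += 1; } // step At(2)
        else                        { m -= MPerBlock; }         // step At(1)
        std::printf("next iteration starts at (m=%d, k=%d)\n", m, k);
        return 0;
    }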
@@ -357,6 +373,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                                    make_tuple(I0, I0),
                                    b_scale_thread_buf);

+           // a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);

            // Local prefill 1
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
@@ -468,18 +485,32 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                                         b_thread_buf);
                    });
                });

+               // a_scale_thread_copy.Run(a_scale_grid_desc,
+               //                         a_scale_grid_buf,
+               //                         a_scale_thread_desc,
+               //                         make_tuple(I0, I0),
+               //                         a_scale_thread_buf);
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_scale_thread_copy.Run(a_scale_grid_desc,
                                            a_scale_grid_buf,
                                            a_scale_thread_desc,
                                            make_tuple(m0, I0),
                                            a_scale_thread_buf);

                    a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                           a_scale_thread_copy_step.At(Number<0>{}));
                });
-               a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
-                                                      a_scale_thread_copy_step.At(Number<1>{}));
+
+               if(num_loop_per_scale == 1)
+               {
+                   a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+                                                          a_scale_thread_copy_step.At(Number<2>{}));
+               }
+               else
+               {
+                   a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+                                                          a_scale_thread_copy_step.At(Number<1>{}));
+               }

                b_scale_thread_copy.Run(b_scale_grid_desc,
                                        b_scale_grid_buf,
@@ -487,6 +518,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                                        make_tuple(I0, I0),
                                        b_scale_thread_buf);

+               // a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp

@@ -1363,16 +1363,16 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
         constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
             make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));

-        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
-        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
-        auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
-                               (get_thread_local_1d_id() / 64) / NWaves * MPerXdl;
-        // auto a_thread_offset = get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 128) * MPerXdl;
-
         constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
+            make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
+
+        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
+        // auto a_thread_offset =
+        //     get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) % MWaves * MPerXdl;
+        auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
+                               (get_thread_local_1d_id() / 128) * MPerXdl;

         auto a_scale_thread_copy =
             ThreadwiseTensorSliceTransfer_v2<AScaleType,
                                              AScaleType,
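The new a_thread_offset gives each lane the row whose per-row A scale it will read. A tiny sketch of the mapping it produces, using MPerXdl = 16 and BlockSize = 256 taken from the example instance (assumed values, not part of this header):

    #include <cstdio>

    // Per-lane starting row for the A-scale read, as computed in the new code:
    //   a_thread_offset = tid % MPerXdl + (tid / 128) * MPerXdl;
    // With MPerXdl = 16 and 256 threads, lanes 0..127 start at rows 0..15 and
    // lanes 128..255 at rows 16..31.
    int main()
    {
        const int MPerXdl = 16;
        for(int tid = 0; tid < 256; tid += 64) // one lane per wave is enough to see the pattern
        {
            const int row = tid % MPerXdl + (tid / 128) * MPerXdl;
            std::printf("tid %3d -> a_scale row %d\n", tid, row);
        }
        return 0;
    }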
@@ -1384,7 +1384,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
                                              1,
                                              1,
                                              false>(
                 a_scale_grid_desc_am_ak,
                 make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0));

         auto b_scale_thread_copy =
             ThreadwiseTensorSliceTransfer_v2<BScaleType,
@@ -1399,8 +1400,11 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
                                              false>(
                 b_scale_grid_desc_bn_ak,
                 make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

+        // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
         constexpr auto a_scale_thread_slice_copy_step =
-            make_tuple(make_multi_index(MWaves * MPerXdl, 0), make_multi_index(-MPerBlock, 1));
+            make_tuple(make_multi_index(MWaves * MPerXdl, 0),
+                       make_multi_index(-MPerBlock, 0),
+                       make_multi_index(-MPerBlock, 1));
         constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1);

         const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
@@ -1443,24 +1447,28 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
         constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
         constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

-        // TODO: hacky, fix it!
-        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
-            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-        // TODO: hacky, fix it!
-        // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
-        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
-            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
-        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
-        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
-        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
-        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
-        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
-        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
-        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
+        // transposed XDL
+        // // TODO: hacky, fix it!
+        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+
+        // // TODO: hacky, fix it!
+        // only used to get lengths
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
+            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+
+        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
+        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
+        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
+        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
+        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
+        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
+        constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
+        constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);

         constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
             GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
@@ -1469,24 +1477,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
             static_cast<CShuffleDataType*>(p_shared),
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
             make_tuple(
                 make_freeze_transform(I0),
                 make_unmerge_transform(make_tuple(
                     Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                     M1,                                      // M1 = MWave
-                    M2,                                      // M2 * M3 * M4 = MPerXdl
-                    M3,
-                    M4)),
+                    M2)),                                    // M2 = MPerXdl
                 make_freeze_transform(I0),
                 make_unmerge_transform(make_tuple(
                     Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                     N1,                                      // N1 = NWave
-                    N2))),                                   // N2 = NPerXdl
+                    N2,                                      // N2 * N3 * N4 = NPerXdl
+                    N3,
+                    N4))),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(
-                Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
+                Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));

         // calculate origin of thread output tensor on global memory
         // blockwise GEMM c matrix starting index
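Both C-block descriptors cover the same tile of the shuffle buffer; the switch only moves the fine split from the M side (M2 * M3 * M4 = MPerXdl) to the N side (N2 * N3 * N4 = NPerXdl), matching the transposed XDL output. A quick consistency check with the tile numbers assumed from the example instance:

    // Dimension bookkeeping for the transposed-XDL C layout (values assumed from the
    // example instance: MPerBlock = NPerBlock = 128, MPerXdl = NPerXdl = 16,
    // MXdlPerWave = NXdlPerWave = 4, MWave = NWave = 2). Illustrative only.
    constexpr int M0 = 4, M1 = 2, M2 = 16;     // MXdlPerWave, MWave, MPerXdl
    constexpr int N0 = 4, N1 = 2, NPerXdl = 16; // N2 * N3 * N4 == NPerXdl
    static_assert(M0 * M1 * M2 == 128, "coarse M split still covers MPerBlock");
    static_assert(N0 * N1 * NPerXdl == 128, "the fine N split replaces the old M2*M3*M4 split");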
@@ -1496,57 +1504,57 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
         const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
         const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

-        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
+        const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
             make_single_stage_tensor_adaptor(
-                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
-                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
+                make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
+                make_tuple(Sequence<0, 1, 2>{}),
                 make_tuple(Sequence<0>{}));

         const auto m_thread_data_on_block_idx =
-            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
+            m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
                 make_multi_index(m_thread_data_on_block));

-        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
+        const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
             make_single_stage_tensor_adaptor(
-                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
-                make_tuple(Sequence<0, 1, 2>{}),
+                make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
+                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                 make_tuple(Sequence<0>{}));

         const auto n_thread_data_on_block_idx =
-            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
+            n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
                 make_multi_index(n_thread_data_on_block));

         // shuffle: threadwise copy C from VGPR to LDS
         auto c_thread_copy_vgpr_to_lds =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                CShuffleDataType,
-                                               decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-                                               decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-                                               ck::tensor_operation::element_wise::PassThrough,
+                                               decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+                                               decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+                                               tensor_operation::element_wise::PassThrough,
                                                Sequence<CShuffleMXdlPerWavePerShuffle,
                                                         CShuffleNXdlPerWavePerShuffle,
                                                         I1,
                                                         I1,
-                                                        M2,
                                                         I1,
-                                                        M4,
-                                                        I1>,
+                                                        N2,
+                                                        I1,
+                                                        N4>,
                                                Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                7,
                                                1,
                                                InMemoryDataOperationEnum::Set,
                                                1,
                                                true>{
-                c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                 make_multi_index(0,
                                  0,
                                  m_thread_data_on_block_idx[I1],
                                  n_thread_data_on_block_idx[I1],
                                  m_thread_data_on_block_idx[I2],
-                                 m_thread_data_on_block_idx[I3],
-                                 m_thread_data_on_block_idx[I4],
-                                 n_thread_data_on_block_idx[I2]),
-                ck::tensor_operation::element_wise::PassThrough{}};
+                                 n_thread_data_on_block_idx[I2],
+                                 n_thread_data_on_block_idx[I3],
+                                 n_thread_data_on_block_idx[I4]),
+                tensor_operation::element_wise::PassThrough{}};

         using EDataType = CDataType;
@@ -1628,18 +1636,17 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
             make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
             c_element_op};

-        // space filling curve for threadwise C in VGPR
         constexpr auto sfc_c_vgpr =
-            SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
+            SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, 1, N2, 1, N4>,
                               Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                               Sequence<CShuffleMXdlPerWavePerShuffle,
                                        CShuffleNXdlPerWavePerShuffle,
                                        1,
                                        1,
-                                       M2,
                                        1,
-                                       M4,
-                                       1>>{};
+                                       N2,
+                                       1,
+                                       N4>>{};

         constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
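The space-filling-curve change mirrors the descriptor swap; since GetNumOfAccess divides the element counts by the per-access scalars dimension by dimension, the number of VGPR-to-LDS shuffle passes should be unchanged by the layout switch. A sketch with the CShuffle repeat counts assumed from the example instance:

    // Access count implied by the space-filling curve: the tile lengths and the
    // per-access scalars cancel except for the XdlPerWave / CShuffle ratios.
    // Values assumed: MXdlPerWave = NXdlPerWave = 4,
    // CShuffleMXdlPerWavePerShuffle = 1, CShuffleNXdlPerWavePerShuffle = 2.
    constexpr int MXdlPerWave = 4, NXdlPerWave = 4;
    constexpr int CShuffleM = 1, CShuffleN = 2;
    constexpr int num_access = (MXdlPerWave / CShuffleM) * (NXdlPerWave / CShuffleN); // 8
    static_assert(num_access == 8, "eight shuffle iterations per tile under these assumptions");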
@@ -1659,10 +1666,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
                 block_sync_lds();

                 // each thread write its data from VGPR to LDS
-                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                               sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                               c_thread_buf,
-                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                              c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                               c_shuffle_block_buf);

                 // make sure it's safe to read from LDS