Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
87ad5225
Commit
87ad5225
authored
Oct 28, 2024
by
aska-0096
Browse files
Bug fix
parent
a75152d6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
26 deletions
+46
-26
example/65_gemm_multiply_multiply/gemm_fp16int8_b_scale.cpp
example/65_gemm_multiply_multiply/gemm_fp16int8_b_scale.cpp
+7
-7
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp
...n/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp
+29
-16
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_scale.hpp
...pu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_scale.hpp
+10
-3
No files found.
example/65_gemm_multiply_multiply/gemm_fp16int8_b_scale.cpp
View file @
87ad5225
...
@@ -62,7 +62,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
...
@@ -62,7 +62,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
// static constexpr ck::index_t Scale_Block_M = 128;
// static constexpr ck::index_t Scale_Block_M = 128;
static constexpr ck::index_t Scale_Block_N = 1;
static constexpr ck::index_t Scale_Block_N = 1;
static constexpr ck::index_t Scale_Block_K = 128;
static constexpr ck::index_t Scale_Block_K = 64;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
// clang-format off
// clang-format off
...
@@ -70,18 +70,18 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X
...
@@ -70,18 +70,18 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X
A0DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
A0DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, Scale_Block_N, Scale_Block_K,
256, Scale_Block_N, Scale_Block_K,
128, 128, 128,
128, 128, 64,
// 16, 16,
// 16, 16,
8, 8, 16, 16, 4, 4,
8, 8, 16, 16, 4, 4,
// S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
// ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
// clang-format on
// clang-format on
template <typename IntType>
template <typename IntType>
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp
View file @
87ad5225
...
@@ -346,15 +346,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
...
@@ -346,15 +346,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
// make_tuple(I0, I0),
// make_tuple(I0, I0),
// a_scale_thread_buf);
// a_scale_thread_buf);
b_scale_thread_copy.Run(b_scale_grid_desc,
                        b_scale_grid_buf,
                        b_scale_thread_desc,
                        make_tuple(I0, I0),
                        b_scale_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
    b_scale_thread_copy.Run(b_scale_grid_desc,
                            b_scale_grid_buf,
                            b_scale_thread_desc,
                            make_tuple(n0, I0),
                            b_scale_thread_buf);
// a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
    b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                           b_scale_thread_copy_step.At(Number<0>{}));
});
// Local prefill 1
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
...
@@ -470,15 +475,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
...
@@ -470,15 +475,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
// make_tuple(I0, I0),
// make_tuple(I0, I0),
// a_scale_thread_buf);
// a_scale_thread_buf);
b_scale_thread_copy.Run(b_scale_grid_desc,
                        b_scale_grid_buf,
                        b_scale_thread_desc,
                        make_tuple(I0, I0),
                        b_scale_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
    b_scale_thread_copy.Run(b_scale_grid_desc,
                            b_scale_grid_buf,
                            b_scale_thread_desc,
                            make_tuple(n0, I0),
                            b_scale_thread_buf);
    b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                           b_scale_thread_copy_step.At(Number<0>{}));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
// a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
// a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
// a_scale_thread_copy_step);
// a_scale_thread_copy_step);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
HotLoopScheduler();
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier(0);
i += 1;
i += 1;
...
@@ -517,7 +530,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
...
@@ -517,7 +530,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(a_scale_thread_buf[I0]) *
// type_convert<AccDataType>(a_scale_thread_buf[I0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]);
type_convert<AccDataType>(b_scale_thread_buf[n0]);
});
});
});
});
});
});
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_scale.hpp
View file @
87ad5225
...
@@ -1383,21 +1383,28 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
...
@@ -1383,21 +1383,28 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
// a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM,
// a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM,
// 0));
// 0));
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
auto b_thread_offset =
    get_thread_local_1d_id() % NPerXdl + (get_thread_local_1d_id() / 64) % NWaves * NPerXdl;
auto b_scale_thread_copy =
    ThreadwiseTensorSliceTransfer_v2<BScaleType,
                                     BScaleType,
                                     decltype(b_scale_grid_desc_bn_ak),
                                     decltype(b_scale_thread_desc),
                                     Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
                                     Sequence<0, 1>,
                                     1,
                                     1,
                                     1,
                                     false>(
        b_scale_grid_desc_bn_ak,
        make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
auto b_scale_thread_copy =
    ThreadwiseTensorSliceTransfer_v2<BScaleType,
                                     BScaleType,
                                     decltype(b_scale_grid_desc_bn_ak),
                                     decltype(b_scale_thread_desc),
                                     Sequence<1, ScaleSliceSizeK>,
                                     Sequence<0, 1>,
                                     1,
                                     1,
                                     1,
                                     false>(
        b_scale_grid_desc_bn_ak,
        make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0));
// constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
// constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1);
constexpr auto b_scale_thread_slice_copy_step =
    make_tuple(make_multi_index(NWaves * NPerXdl, 0), make_multi_index(-NPerBlock, 1));
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment