gaoqiong / composable_kernel_ROCM · Commits

Commit 009e76c9, authored Oct 31, 2024 by mtgu0705

Support int8_dequant interwave scheduler

Parent: b07e21a4
Showing 4 changed files with 494 additions and 26 deletions (+494 −26).
example/65_gemm_multiply_multiply/gemm_fp16int4_b_scale.cpp   (+2 −2)
example/65_gemm_multiply_multiply/gemm_fp16int8_b_scale.cpp   (+1 −1)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp   (+469 −0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_scale.hpp   (+22 −23)
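This commit adds an inter-wave scheduled specialization of the v1 B-scale blockwise GEMM pipeline (the bulk of the additions, in blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp), switches the fp16 × int8 dequant example over to it, and updates the fp16 × int4 example and the gridwise caller to match.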
example/65_gemm_multiply_multiply/gemm_fp16int4_b_scale.cpp
@@ -21,7 +21,7 @@
 #include "ck/utility/blkgemmpipe_scheduler.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
+// #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -80,7 +80,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X
     // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
     1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
     // ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
-    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, PermuteB>;
+    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, EDataType, EDataType, false, PermuteB>;
 // clang-format on
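In the fp16 × int4 example the instance gains two explicit EDataType arguments ahead of the existing false, PermuteB pair; this presumably tracks a widened template signature on the B-scale device op, whose parameter meaning is not visible in this hunk.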
example/65_gemm_multiply_multiply/gemm_fp16int8_b_scale.cpp
@@ -81,7 +81,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X
     // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
     1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
     // ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
-    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
+    ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
 // clang-format on

 template <typename IntType>
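This one-line change is the switch named in the commit title: the fp16 × int8 dequant example now instantiates its DeviceOpInstance with the inter-wave scheduler and the v1 pipeline rather than intra-wave v3. Schematically (a sketch only, with the tiling arguments elided as in the hunk header):

// before: DeviceOpInstance = ...<..., ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
// after:  DeviceOpInstance = ...<..., ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;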
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp
@@ -434,4 +434,473 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
    using Base::c_thread_desc_;
};

// Pipeline with inter-wave scheduling
template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Interwave,
                                               BlockSize,
                                               ADataType,
                                               BDataType,
                                               ComputeDataType,
                                               AccDataType,
                                               ATileDesc,
                                               BTileDesc,
                                               AMmaTileDesc,
                                               BMmaTileDesc,
                                               ABlockTransferSrcScalarPerVector,
                                               BBlockTransferSrcScalarPerVector,
                                               MPerBlock,
                                               NPerBlock,
                                               KPerBlock,
                                               MPerXDL,
                                               NPerXDL,
                                               MRepeat,
                                               NRepeat,
                                               KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>
{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
    using Base::A_K1;
    using Base::B_K1;
    using Base::I0;
    using Base::I1;
    using Base::KPerThread;
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;
    static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
    static constexpr index_t KPerInnerLoop  = math::max(KPerThread / NumMacClusters, KPack);
    static constexpr index_t KRepeat        = KPerThread / KPerInnerLoop;

    static constexpr index_t PrefetchStages  = 1;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;
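    // Illustrative numbers (not from this diff): with KPerThread = 16,
    // CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS = 2 and KPack = 8,
    // KPerInnerLoop = max(16 / 2, 8) = 8 and KRepeat = 16 / 8 = 2, i.e. each
    // thread's K slice is split into two MAC clusters whose compute different
    // waves can interleave with memory work.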
    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }
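    // With a single prefetch stage the pipeline has a hot loop whenever
    // num_loop > 1, and the tail is always TailNumber::Full regardless of the
    // loop count.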
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleGridBuffer,
              typename BScaleGridDesc,
              typename BScaleThreadDesc,
              typename BScaleThreadTransfer,
              typename BScaleThreadTransferStep>
    __device__ void Run(
        // ABlockCopy
        const AGridDesc& a_grid_desc,
        const ABlockDesc& a_block_desc,
        ABlockTransfer& a_blockwise_copy,
        const AGridBuffer& a_grid_buf,
        ABlockBuffer& a_block_buf,
        const ABlockTransferStep& a_block_copy_step,
        // BBlockCopy
        const BGridDesc& b_grid_desc,
        const BBlockDesc& b_block_desc,
        BBlockTransfer& b_blockwise_copy,
        const BGridBuffer& b_grid_buf,
        BBlockBuffer& b_block_buf,
        const BBlockTransferStep& b_block_copy_step,
        // CThread
        CThreadBuffer& c_thread_buf,
        // BScaleThreadCopy
        const BScaleGridDesc& b_scale_grid_desc,
        const BScaleThreadDesc& b_scale_thread_desc,
        BScaleThreadTransfer& b_scale_thread_copy,
        const BScaleGridBuffer& b_scale_grid_buf,
        const BScaleThreadTransferStep& b_scale_thread_copy_step,
        // num_loop
        index_t num_loop,
        index_t num_loop_per_scale) const
    {
        ignore = num_loop_per_scale;
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_scale_thread_desc.GetElementSpaceSize());
        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // B Scale copy
        static_for<0, NRepeat, 1>{}([&](auto n0) {
            b_scale_thread_copy.Run(b_scale_grid_desc,
                                    b_scale_grid_buf,
                                    b_scale_thread_desc,
                                    make_tuple(n0, I0),
                                    b_scale_thread_buf);

            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<0>{}));
        });

        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                               b_scale_thread_copy_step.At(Number<1>{}));
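        // The scale window moves in two strides: component 0 steps to the next
        // NRepeat slot inside the loop above; component 1 then advances the
        // window for the next K iteration once all NRepeat scales are loaded.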
        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Initialize C
        c_thread_buf.Clear();

        auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();
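        // c_thread_buf_per_scale holds MFMA partial sums for one scale granule
        // so they can be multiplied by the per-N b_scale value before being
        // accumulated into c_thread_buf.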
        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                // -------------------------------------------------------------------------------------------
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds();
                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
                                           a_block_buf,
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k0, I0),
                                           a_thread_buf);
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                               make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
                                               b_block_buf,
                                               b_thread_desc_,
                                               make_tuple(n0, I0, k0, I0),
                                               b_thread_buf);
                        });
                    });

                    __builtin_amdgcn_sched_barrier(0);
                    // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
                    // but except the first, as we can shorten non-MAC cluster a bit and there's no
                    // observable negative impact. The desired effect is waves in a workgroup
                    // executing MAC in sync. This avoids some out-of-sync waves hijacking MAC
                    // resource from other workgroups and reducing the chance of latency hiding by
                    // waiting for the rest of the workgroup at the eventual sync point.
                    if constexpr(k0.value != 0 || KRepeat == 1)
                    {
                        __builtin_amdgcn_s_barrier();
                        __builtin_amdgcn_sched_barrier(0);
                    }
                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            static_for<0, NRepeat, 1>{}([&](auto n0) {
                                vector_type<ComputeDataType, KPack> a_thread_vec;
                                vector_type<ComputeDataType, KPack> b_thread_vec;

                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                            make_tuple(m0, I0, k0, k_ + ik))>{}];
                                    b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                            make_tuple(n0, I0, k0, k_ + ik))>{}];
                                });

                                using mfma_input_type =
                                    typename vector_type<ComputeDataType,
                                                         xdlops_gemm.K1PerXdlops>::type;

                                // constexpr index_t c_offset =
                                //     c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                // The block_sync_lds() here performs double duty:
                                // A) safeguard against data hazard because barrier from
                                // blockwise_gemm is moved here B) reduce VMEM FIFO congestion by
                                // applying small delays to different wavefronts It is performed
                                // near the end of MAC cluster to minimize lgkmcnt penalty
                                if constexpr(k0.value == KRepeat - 1 &&
                                             k_.value == KPerInnerLoop - KPack &&
                                             m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
                                {
                                    __builtin_amdgcn_sched_barrier(0);
                                    block_sync_lds();
                                    __builtin_amdgcn_sched_barrier(0);
                                }
                                xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                                b_thread_vec.template AsType<mfma_input_type>(),
                                                c_thread_buf.GetVectorTypeReference(I0));

                                static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                                    constexpr index_t c_offset =
                                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                                    c_thread_buf(Number<c_offset>{}) +=
                                        c_thread_buf_per_scale[Number<t>{}] *
                                        type_convert<AccDataType>(b_scale_thread_buf[n0]);
                                });

                                if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
                                {
                                    __builtin_amdgcn_sched_barrier(0);
                                    __builtin_amdgcn_s_setprio(1);
                                    __builtin_amdgcn_sched_barrier(0);
                                }
                            });
                        });
                    });
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_setprio(0);
                    __builtin_amdgcn_sched_barrier(0);
                });

                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_scale_thread_copy.Run(b_scale_grid_desc,
                                            b_scale_grid_buf,
                                            b_scale_thread_desc,
                                            make_tuple(n0, I0),
                                            b_scale_thread_buf);

                    b_scale_thread_copy.MoveSrcSliceWindow(
                        b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
                });

                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                       b_scale_thread_copy_step.At(Number<1>{}));

                // block_sync_lds();
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                i += 1;
            } while(i < (num_loop - 1));
        }
        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
            block_sync_lds();
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
                                       a_block_buf,
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k0, I0),
                                       a_thread_buf);
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
                                           b_block_buf,
                                           b_thread_desc_,
                                           make_tuple(n0, I0, k0, I0),
                                           b_thread_buf);
                    });
                });

                __builtin_amdgcn_sched_barrier(0);
                if constexpr(k0.value != 0 || KRepeat == 1)
                {
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                }
                static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0, I0, k0, k_ + ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0, I0, k0, k_ + ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            // constexpr index_t c_offset =
                            //     c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            if constexpr(k0.value == KRepeat - 1 &&
                                         k_.value == KPerInnerLoop - KPack &&
                                         m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
                            {
                                __builtin_amdgcn_sched_barrier(0);
                                block_sync_lds();
                                __builtin_amdgcn_sched_barrier(0);
                            }
                            xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                            b_thread_vec.template AsType<mfma_input_type>(),
                                            // c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                                            c_thread_buf_per_scale.GetVectorTypeReference(I0));

                            static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                                c_thread_buf(Number<c_offset>{}) +=
                                    c_thread_buf_per_scale[Number<t>{}] *
                                    type_convert<AccDataType>(b_scale_thread_buf[n0]);
                            });

                            if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
                            {
                                __builtin_amdgcn_sched_barrier(0);
                                __builtin_amdgcn_s_setprio(1);
                                __builtin_amdgcn_sched_barrier(0);
                            }
                        });
                    });
                });
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(0);
                __builtin_amdgcn_sched_barrier(0);
            });
        }
    }
    protected:
    // K->M loopover
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
        make_tuple(Number<KPerInnerLoop>{},
                   Number<KRepeat * MRepeat * KPerInnerLoop>{},
                   Number<MRepeat * KPerInnerLoop>{},
                   I1));

    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
        make_tuple(Number<KPerInnerLoop>{},
                   Number<KRepeat * NRepeat * KPerInnerLoop>{},
                   Number<NRepeat * KPerInnerLoop>{},
                   I1));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
                                                         ComputeDataType,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPerInnerLoop>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         A_K1,
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
                                                         ComputeDataType,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPerInnerLoop>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         B_K1,
                                                         B_K1>;

    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};

    using Base::c_thread_desc_;
    // using Base::a_thread_copy_;
    // using Base::a_thread_desc_;
    // using Base::b_thread_copy_;
    // using Base::b_thread_desc_;
    // using Base::c_thread_desc_;
};

} // namespace ck
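Taken together, the pipeline computes a GEMM whose B operand is dequantized block-wise: MFMA partial sums for one scale granule are accumulated first, then multiplied by the per-N b_scale value and folded into the C accumulator. A minimal scalar sketch of that math (illustrative names and plain loops, assuming K is a multiple of the scale granularity scale_k; this is not the pipeline's API):

#include <cstdint>
#include <vector>

// Scalar reference for a B-scale dequant GEMM: C = A * dequant(B_q, B_scale).
void gemm_b_scale_ref(const std::vector<float>& a,       // M x K
                      const std::vector<int8_t>& b_q,    // K x N, quantized
                      const std::vector<float>& b_scale, // (K / scale_k) x N
                      std::vector<float>& c,             // M x N
                      int M, int N, int K, int scale_k)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k0 = 0; k0 < K; k0 += scale_k)
            {
                // Accumulate one scale granule first (the role played by
                // c_thread_buf_per_scale above), then apply the scale.
                float per_scale = 0.f;
                for(int k = k0; k < k0 + scale_k; ++k)
                    per_scale += a[m * K + k] * static_cast<float>(b_q[k * N + n]);
                acc += per_scale * b_scale[(k0 / scale_k) * N + n];
            }
            c[m * N + n] = acc;
        }
}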
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_scale.hpp
@@ -1443,29 +1443,28 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
        const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
                                                                         a_block_buf,
                                                                         a_block_slice_copy_step,
                                                                         b_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         b_grid_buf,
                                                                         b_block_buf,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         b_scale_grid_desc_bn_ak,
                                                                         b_scale_thread_desc,
                                                                         b_scale_thread_copy,
                                                                         b_scale_grid_buf,
                                                                         b_scale_thread_slice_copy_step,
                                                                         num_k_block_main_loop,
                                                                         num_k_block_per_scale);

        // shuffle C and write out
        {
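For reference, the scale granularity arithmetic above: one b_scale value spans ScaleBlockK elements of K, so the caller derives how many KPerBlock main-loop iterations share a single scale. A worked instance with illustrative numbers (not from this diff):

constexpr int ScaleBlockK = 128; // illustrative scale granularity along K
constexpr int KPerBlock   = 64;  // illustrative K tile per main-loop iteration
constexpr int num_k_block_per_scale = ScaleBlockK / KPerBlock; // = 2
static_assert(num_k_block_per_scale == 2, "one scale covers two K blocks");

Note that the new inter-wave Run body discards this hint (ignore = num_loop_per_scale) and instead advances the B-scale window on every main-loop iteration.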