gaoqiong / composable_kernel_ROCM / Commits / 03c2ba3a
"include/ck/utility/statically_indexed_array.hpp" did not exist on "b491ebf38480bc0d6cb329ba6825dee610c59097"
Commit 03c2ba3a, authored Dec 04, 2024 by aska-0096
bug fix + performance opt + clangformat
parent 1a324dfb
Showing 14 changed files with 372 additions and 315 deletions.
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp  +25 -25
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp  +21 -21
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp  +48 -50
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp  +57 -36
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp  +21 -24
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp  +1 -1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp  +24 -6
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  +25 -17
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp  +63 -68
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  +1 -1
library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp  +15 -15
library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp  +13 -4
profiler/include/profiler/profile_gemm_b_scale_impl.hpp  +40 -40
profiler/src/profile_gemm_b_scale.cpp  +18 -7
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp

@@ -92,29 +92,6 @@ constexpr auto BlockGemmPipeline_Selector()
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
    {
        return BlockwiseGemmXdlops_pipeline_v3_b_scale<BlkGemmPipeSche, BlockSize, ADataType, BDataType,
            ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
            ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
    }
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
    {
        return BlockwiseGemmXdlops_pipeline_v4_b_scale<BlkGemmPipeSche, BlockSize, ADataType, BDataType,
...
@@ -135,9 +112,9 @@ constexpr auto BlockGemmPipeline_Selector()
            NRepeat, KPack>{};
    }
-   else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)
+   else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
    {
-       return BlockwiseGemmXdlops_pipeline_v5<BlkGemmPipeSche,
+       return BlockwiseGemmXdlops_pipeline_v4_b_scale<BlkGemmPipeSche,
            BlockSize, ADataType, BDataType,
...
@@ -158,6 +135,29 @@ constexpr auto BlockGemmPipeline_Selector()
            NRepeat, KPack>{};
    }
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)
    {
        return BlockwiseGemmXdlops_pipeline_v5<BlkGemmPipeSche, BlockSize, ADataType, BDataType,
            ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
            ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
    }
    else
    {
        std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
...
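The selector above is a compile-time dispatch: an `if constexpr` chain keyed on the pipeline-version non-type template parameter chooses which blockwise-GEMM pipeline type to instantiate, which is why the v4 branch fixed here must name the `*_b_scale` variant rather than the plain v5 pipeline. A minimal sketch of the same pattern, with illustrative names rather than the CK template machinery (the real selector prints a diagnostic in its final `else` instead of asserting):

    #include <iostream>

    enum class PipelineVersion { v1, v3, v4, v5 };

    struct PipelineV3BScale { static constexpr const char* name = "v3_b_scale"; };
    struct PipelineV4BScale { static constexpr const char* name = "v4_b_scale"; };
    struct PipelineV5       { static constexpr const char* name = "v5"; };

    template <PipelineVersion Ver>
    constexpr auto make_pipeline()
    {
        // Each branch returns a different type; only the selected branch is instantiated.
        if constexpr(Ver == PipelineVersion::v3)      { return PipelineV3BScale{}; }
        else if constexpr(Ver == PipelineVersion::v4) { return PipelineV4BScale{}; }
        else if constexpr(Ver == PipelineVersion::v5) { return PipelineV5{}; }
        else
        {
            // Fires only when an unsupported version is instantiated.
            static_assert(Ver == PipelineVersion::v3, "pipeline configuration is not available");
        }
    }

    int main()
    {
        // Prints "v4_b_scale": the v4 case now maps to the b-scale pipeline.
        std::cout << decltype(make_pipeline<PipelineVersion::v4>())::name << '\n';
    }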
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp

@@ -59,25 +59,25 @@ template <index_t BlockSize,
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intrawave,
                                               BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
                                               ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
                                               ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
                                               MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL,
                                               MRepeat, NRepeat, KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize, ADataType, BDataType,
    (clang-format re-indentation of the specialization argument list; no functional change)
...
@@ -170,7 +170,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              // BScale Thread Copy
              typename BScaleGridBuffer,
              typename BScaleGridDesc,
              typename BScaleThreadDesc,
...
@@ -209,7 +209,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());
        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_scale_thread_desc.GetElementSpaceSize());
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp

@@ -59,25 +59,25 @@ template <index_t BlockSize,
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intrawave,
                                               BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
                                               ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
                                               ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
                                               MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL,
                                               MRepeat, NRepeat, KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize, ADataType, BDataType,
    (clang-format re-indentation only)
...
@@ -546,25 +546,25 @@ template <index_t BlockSize,
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Interwave,
                                               BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
                                               ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
                                               ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
                                               MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL,
                                               MRepeat, NRepeat, KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize, ADataType, BDataType,
    (clang-format re-indentation only)
...
@@ -719,16 +719,16 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
              const BBlockTransferStep& b_block_copy_step,
              CThreadBuffer& c_thread_buf,
              const BScaleGridDesc& b_scale_grid_desc,
              // BScaleThreadCopy
              const BScaleThreadDesc& b_scale_thread_desc,
              BScaleThreadTransfer& b_scale_thread_copy,
              const BScaleGridBuffer& b_scale_grid_buf,
              const BScaleThreadTransferStep& b_scale_thread_copy_step,
              // num loop
              index_t num_loop,
              index_t num_loop_per_scale) const
    {
        ignore = num_loop_per_scale;

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
...
@@ -751,7 +751,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
                    b_scale_thread_desc,
                    make_tuple(n0, I0),
                    b_scale_thread_buf);
                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                       b_scale_thread_copy_step.At(Number<0>{}));
            });
...
@@ -864,16 +864,16 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
                    }
                });
                // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t)
                // {
                //     constexpr index_t c_offset =
                //         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                //     c_thread_buf(Number<c_offset>{}) +=
                //         c_thread_buf_per_scale[Number<t>{}] *
                //         type_convert<AccDataType>(b_scale_thread_buf[n0]);
                // });
            });
        });

        __builtin_amdgcn_sched_barrier(0);
        __builtin_amdgcn_s_setprio(0);
        __builtin_amdgcn_sched_barrier(0);
...
@@ -983,10 +983,9 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
                //     constexpr index_t c_offset =
                //         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                //     c_thread_buf(Number<c_offset>{}) +=
                //         c_thread_buf_per_scale[Number<t>{}] *
                //         type_convert<AccDataType>(b_scale_thread_buf[n0]);
                // });
            });
        });
        __builtin_amdgcn_sched_barrier(0);
...
@@ -1084,7 +1083,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
                //         c_thread_buf_per_scale[Number<t>{}] *
                //         type_convert<AccDataType>(b_scale_thread_buf[n0]);
                // });
            });
        });
        __builtin_amdgcn_sched_barrier(0);
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp

@@ -59,25 +59,25 @@ template <index_t BlockSize,
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intrawave,
                                               BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
                                               ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
                                               ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
                                               MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL,
                                               MRepeat, NRepeat, KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize, ADataType, BDataType,
    (clang-format re-indentation only)
...
@@ -295,25 +295,24 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
              BBlockBuffer& b_block_buf,
              const BBlockTransferStep& b_block_copy_step,
              CThreadBuffer& c_thread_buf,
              // BScaleThreadCopy
              const BScaleGridDesc& b_scale_grid_desc,
              const BScaleThreadDesc& b_scale_thread_desc,
              BScaleThreadTransfer& b_scale_thread_copy,
              const BScaleGridBuffer& b_scale_grid_buf,
              const BScaleThreadTransferStep& b_scale_thread_copy_step,
              // num loop
              index_t num_loop,
              index_t num_loop_per_scale) const
    {
        __builtin_amdgcn_sched_barrier(0);
        ignore = num_loop_per_scale;
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());
        // B scale buffer
        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_scale_thread_desc.GetElementSpaceSize());
...
@@ -328,14 +327,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
            b_scale_thread_copy.Run(b_scale_grid_desc,
                                    b_scale_grid_buf,
                                    b_scale_thread_desc,
-                                   make_tuple(n0, I0, I0),
+                                   make_tuple(n0, I0),
                                    b_scale_thread_buf);
            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<0>{}));
        });
-       b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
-                                              b_scale_thread_copy_step.At(Number<1>{}));
+       if(num_loop_per_scale == 1)
+       {
+           b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
+                                                  b_scale_thread_copy_step.At(Number<2>{}));
+       }
+       else
+       {
+           b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
+                                                  b_scale_thread_copy_step.At(Number<1>{}));
+       }

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
...
@@ -350,7 +358,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
        // Initialize C
        c_thread_buf.Clear();

-       auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();
+       // need actually?
+       StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
+                                 AccDataType,
+                                 1,
+                                 xdlops_gemm.GetRegSizePerXdlops(),
+                                 true>
+           c_thread_buf_per_scale;

        // Local prefetch 1
        block_sync_lds();
...
@@ -415,10 +429,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                        // constexpr index_t c_offset =
                        //     c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf_per_scale.GetVectorTypeReference(I0));
                    });
                    static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                        constexpr index_t c_offset =
...
@@ -455,15 +468,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                b_scale_thread_copy.Run(b_scale_grid_desc,
                                        b_scale_grid_buf,
                                        b_scale_thread_desc,
-                                       make_tuple(n0, I0, I0),
+                                       make_tuple(n0, I0),
                                        b_scale_thread_buf);
                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                       b_scale_thread_copy_step.At(Number<0>{}));
            });
-           b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
-                                                  b_scale_thread_copy_step.At(Number<1>{}));
+           if((i + 2) % num_loop_per_scale == 0)
+           {
+               b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
+                                                      b_scale_thread_copy_step.At(Number<2>{}));
+           }
+           else
+           {
+               b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
+                                                      b_scale_thread_copy_step.At(Number<1>{}));
+           }

            HotLoopScheduler();
            __builtin_amdgcn_sched_barrier(0);
...
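The v3 pipeline now keys its B-scale window movement off num_loop_per_scale (ScaleBlockK / KPerBlock): the window takes step index 2 only when the next main-loop iteration crosses a ScaleBlockK boundary, and step index 1 otherwise. A small host-side trace of that control flow, with hypothetical values chosen only for illustration:

    #include <cstdio>

    int main()
    {
        const int num_loop           = 8;
        const int num_loop_per_scale = 4; // ScaleBlockK / KPerBlock (illustrative)

        for(int i = 0; i < num_loop; ++i)
        {
            // Mirrors the new if/else around MoveSrcSliceWindow in the hot loop:
            // step 2 advances to the next scale group, step 1 stays within it.
            const int step_index = ((i + 2) % num_loop_per_scale == 0) ? 2 : 1;
            std::printf("main-loop iter %d: move B-scale window with step index %d\n", i, step_index);
        }
        return 0;
    }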
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp

@@ -268,13 +268,13 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
              BBlockBuffer& b_block_buf,
              const BBlockTransferStep& b_block_copy_step,
              CThreadBuffer& c_thread_buf,
              // BScaleThreadCopy
              const BScaleGridDesc& b_scale_grid_desc,
              const BScaleThreadDesc& b_scale_thread_desc,
              BScaleThreadTransfer& b_scale_thread_copy,
              const BScaleGridBuffer& b_scale_grid_buf,
              const BScaleThreadTransferStep& b_scale_thread_copy_step,
              // num loop
              index_t num_loop,
              index_t num_loop_per_scale) const
    {
...
@@ -284,7 +284,7 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());
        // B scale buffer
        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_scale_thread_desc.GetElementSpaceSize());
...
@@ -409,11 +409,11 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
                                    make_tuple(n0, I0),
                                    b_scale_thread_bufs(lds_read_reg_buf));
            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<0>{}));
        });
        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                               b_scale_thread_copy_step.At(Number<1>{}));

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
...
@@ -426,7 +426,6 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        static_for<0, MRepeat, 1>{}([&](auto m0) {
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                c_thread_buf_per_scale.Clear();
...
@@ -437,32 +436,32 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
                    static_for<0, KPack, 1>{}([&](auto ik) {
                        a_thread_vec.template AsType<ComputeDataType>()(ik) =
                            a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                make_tuple(m0, I0, k0, ik))>{}];
                        b_thread_vec.template AsType<ComputeDataType>()(ik) =
                            b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                make_tuple(n0, I0, k0, ik))>{}];
                    });

                    using mfma_input_type =
                        typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                    // constexpr index_t c_offset =
                    //     c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),
                                    c_thread_buf_per_scale.GetVectorTypeReference(I0));
                });
                static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
                    c_thread_buf(Number<c_offset>{}) +=
                        c_thread_buf_per_scale[Number<t>{}] *
                        type_convert<AccDataType>(b_scale_thread_bufs(mfma_reg_buf)[n0]);
                });
            });
        });
...
@@ -513,15 +512,14 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
                                    b_scale_thread_bufs(lds_read_reg_buf));
            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<0>{}));
        });
        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                               b_scale_thread_copy_step.At(Number<1>{}));

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);

        static_for<0, MRepeat, 1>{}([&](auto m0) {
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                c_thread_buf_per_scale.Clear();
...
@@ -595,10 +593,10 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
                                    b_scale_thread_bufs(lds_read_reg_buf));
            b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                   b_scale_thread_copy_step.At(Number<0>{}));
        });
        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                               b_scale_thread_copy_step.At(Number<1>{}));

        static_for<0, MRepeat, 1>{}([&](auto m0) {
            static_for<0, NRepeat, 1>{}([&](auto n0) {
...
@@ -640,7 +638,6 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
        };

        auto CompFunc = [&](auto mfma_reg_buf) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    c_thread_buf_per_scale.Clear();
...
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp

@@ -99,7 +99,7 @@ struct DeviceGemmV2BScale : public BaseOperator
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        const void* p_b_scale,
                        ck::index_t KSplit,
                        AElementwiseOperation a_element_op,
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp

@@ -35,8 +35,8 @@ template <typename ALayout,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t ScaleBlockN, // scale block for N
          index_t ScaleBlockK, // scale block for K
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
...
@@ -218,7 +218,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
        };

        constexpr index_t minimum_occupancy =
-           BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
+           BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
+               ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
+                  MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) > 128 * 128 * 64 * 2)
+                     ? 1
+                     : 2
+               : 2;

        if(has_main_k_block_loop)
        {
...
@@ -659,12 +664,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
                             index_t StrideB,
                             index_t StrideC,
                             const BScaleDataType* p_b_scale,
                             index_t KBatch,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_c,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideC,
                        p_b_scale,
                        KBatch,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }
...
@@ -680,7 +698,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
                        index_t StrideB,
                        index_t StrideC,
                        const void* p_b_scale,
                        index_t KBatch,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op) override
...
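The occupancy change above makes the launch-bounds hint depend on tile size: under Intrawave scheduling, minimum occupancy drops to 1 only for v3 pipelines whose MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) product exceeds the 128x128x64 fp16 reference tile. A standalone sketch of that arithmetic, assuming nothing beyond what the diff shows (the function and example tiles are illustrative, not the CK device code):

    #include <cstdio>

    int minimum_occupancy(bool intrawave, bool pipeline_v3,
                          long m_per_block, long n_per_block, long k_per_block,
                          long sizeof_a_data)
    {
        if(!intrawave)
            return 2; // Interwave keeps occupancy 2, as before this commit.

        const long tile_product      = m_per_block * n_per_block * k_per_block * sizeof_a_data;
        const long reference_product = 128L * 128L * 64L * 2L; // 128x128x64 tile of 2-byte data

        return (pipeline_v3 && tile_product > reference_product) ? 1 : 2;
    }

    int main()
    {
        std::printf("%d\n", minimum_occupancy(true, true, 256, 256, 64, 2)); // 1: large tile
        std::printf("%d\n", minimum_occupancy(true, true, 128, 128, 64, 2)); // 2: at the threshold
        return 0;
    }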
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -13,9 +13,9 @@ namespace ck {

__host__ __device__ inline half4_t pki4_to_half4(int q)
{
-   const int LO = 0x000f000f;
-   const int HI = 0x00f000f0;
-   const int EX = 0x64006400;
+   constexpr int LO = 0x000f000f;
+   constexpr int HI = 0x00f000f0;
+   constexpr int EX = 0x64006400;

    // Guarantee that the `(a & b) | c` operations are LOP3s.
    // int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
    // int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
...
@@ -23,9 +23,9 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
    int hi = amd_assembly_and_or_b32(q, HI, EX);

    // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
    // directly into `SUB` and `ADD`.
-   const int SUB = 0xE408E408; //-8
-   const int MUL = 0x2c002c00; // 1/16
-   const int ADD = 0xd480d480; //-79
+   constexpr int SUB = 0xE408E408; //-8
+   constexpr int MUL = 0x2c002c00; // 1/16
+   constexpr int ADD = 0xd480d480; //-79

    vector_type<half_t, 4> res;
...
@@ -34,7 +34,15 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
    res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
        bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));

#if 0
    asm volatile("v_and_or_b32 %0, %4, %5, %7 \n \
                  v_and_or_b32 %1, %4, %6, %7 \n \
                  v_pk_add_f16 %2, %0, %8 \n \
                  v_pk_fma_f16 %3, %1, %9, %10 \
                 "
                 : "=v"(lo), "=v"(hi), "=v"(res.template AsType<half2_t>()(Number<0>{})), "=v"(res.template AsType<half2_t>()(Number<1>{}))
                 : "v"(q), "v"(LO), "v"(HI), "s"(EX), "s"(SUB), "v"(MUL), "s"(ADD), "0"(lo), "1"(hi));
#endif

    return res.template AsType<half4_t>()[Number<0>{}];
}
...
@@ -80,14 +88,14 @@ struct PassThroughPack8
    {
#if 1
        int x_permute = 0;

        int bits4_0 = (bit_cast<int>(x) >> 0) & 0xF;
        int bits4_1 = (bit_cast<int>(x) >> 4) & 0xF;
        int bits4_2 = (bit_cast<int>(x) >> 8) & 0xF;
        int bits4_3 = (bit_cast<int>(x) >> 12) & 0xF;
        int bits4_4 = (bit_cast<int>(x) >> 16) & 0xF;
        int bits4_5 = (bit_cast<int>(x) >> 20) & 0xF;
        int bits4_6 = (bit_cast<int>(x) >> 24) & 0xF;
        int bits4_7 = (bit_cast<int>(x) >> 28) & 0xF;

        x_permute |= (bits4_1 << 0);
        x_permute |= (bits4_3 << 4);
...
@@ -111,7 +119,7 @@ struct PassThroughPack8
        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
        result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
        y = result.template AsType<half8_t>()[Number<0>{}];
#else
        vector_type<half_t, 8> dst;
        vector_type<pk_i4_t, 4> src{x};
...
@@ -125,7 +133,7 @@ struct PassThroughPack8
        dst.template AsType<half2_t>()(Number<3>{}) =
            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
        y = dst.template AsType<half8_t>()[Number<0>{}];
#endif
    }
...
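pki4_to_half4 relies on the "magic exponent" trick: OR-ing a 4-bit value v into the low mantissa bits of the fp16 pattern 0x6400 (which encodes 1024.0, whose mantissa LSB is worth exactly 1.0) produces the fp16 number 1024 + v, so one packed add against 0xE408 (which encodes -1032 = -(1024 + 8)) recovers the signed value v - 8; the high nibble sits 4 bits up, so it comes out pre-scaled by 16 and is handled by the FMA with MUL = 1/16. A self-contained host-side check of the low-nibble path, using a minimal fp16 decoder for the demo (it is not the CK implementation, which does the same arithmetic with packed half2 instructions on the GPU):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Decode an IEEE fp16 bit pattern to float (normal numbers only; sufficient
    // for the patterns used in this demo, whose exponents are all nonzero).
    float fp16_to_float(uint16_t h)
    {
        const int sign = (h >> 15) & 0x1;
        const int exp  = (h >> 10) & 0x1F;
        const int man  = h & 0x3FF;
        const float v  = std::ldexp(1.0f + static_cast<float>(man) / 1024.0f, exp - 15);
        return sign ? -v : v;
    }

    int main()
    {
        const uint16_t EX  = 0x6400; // fp16 1024.0; its mantissa LSB is worth 1.0
        const uint16_t SUB = 0xE408; // fp16 -1032.0; folds in the -8 zero point

        for(int nibble = 0; nibble < 16; ++nibble)
        {
            const uint16_t packed = EX | nibble;                            // fp16 value 1024 + nibble
            const float dequant   = fp16_to_float(packed) + fp16_to_float(SUB);
            std::printf("nibble %2d -> %5.1f (expect %d)\n", nibble, dequant, nibble - 8);
        }
        return 0;
    }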
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp

@@ -48,12 +48,7 @@ __global__ void
    // karg);
    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, karg.p_b_scale_grid, p_shared, karg);
#else
    ignore = karg;
#endif // end of if (defined(__gfx9__))
...
@@ -113,8 +108,8 @@ template <typename ALayout,
          typename CElementwiseOperation,
          tensor_operation::device::GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t ScaleBlockN, // scale N
          index_t ScaleBlockK, // scale K
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
...
@@ -605,7 +600,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                 index_t StrideA_,
                 index_t StrideB_,
                 index_t StrideC_,
                 const BScaleType* p_b_scale_grid_,
                 index_t k_batch_,
                 AElementwiseOperation a_element_op_,
                 BElementwiseOperation b_element_op_,
...
@@ -636,7 +631,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
        const ADataType* p_a_grid;
        const BDataType* p_b_grid;
        CDataType* p_c_grid;
        const BScaleType* p_b_scale_grid;
        const AElementwiseOperation a_element_op;
        const BElementwiseOperation b_element_op;
...
@@ -1303,13 +1298,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
        // B Scale grid and buffer
        const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
-           make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
-                      math::integer_divide_ceil(problem.K, ScaleBlockK),
-                      ScaleBlockK),
-           make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1, 0));
+           make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
+                      math::integer_divide_ceil(problem.K, ScaleBlockK)),
+           make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));

        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());

        const AElementwiseOperation a_element_op{};
...
@@ -1438,12 +1430,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                                                          KPerBlock);

        // b scale
        static_assert(KPerBlock <= ScaleBlockK);
        const index_t ScaleSliceSizeN = NXdlPerWave;
        const index_t ScaleSliceSizeK = 1;
        constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
-           make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}, Number<1>{}));
+           make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));

        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
...
@@ -1455,41 +1447,43 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                               BScaleType,
                                               decltype(b_scale_grid_desc_bn_ak),
                                               decltype(b_scale_thread_desc),
-                                              Sequence<1, ScaleSliceSizeK, 1>,
-                                              Sequence<0, 1, 2>,
+                                              Sequence<1, ScaleSliceSizeK>,
+                                              Sequence<0, 1>,
                                               1,
                                               1,
                                               1,
                                               false>(
            b_scale_grid_desc_bn_ak,
-           make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0, 0));
+           make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0));

        constexpr auto b_scale_thread_slice_copy_step =
-           make_tuple(make_multi_index(NWaves * NPerXdl, 0, 0),
-                      make_multi_index(-NPerBlock, KPerBlock / ScaleBlockK, KPerBlock % ScaleBlockK));
+           make_tuple(make_multi_index(NWaves * NPerXdl, 0),
+                      make_multi_index(-NPerBlock, 0),
+                      make_multi_index(-NPerBlock, 1));

        const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
            a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_blockwise_copy, a_grid_buf,
            a_block_buf, a_block_slice_copy_step,
            b_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_blockwise_copy, b_grid_buf,
            b_block_buf, b_block_slice_copy_step,
            c_thread_buf,
            b_scale_grid_desc_bn_ak, b_scale_thread_desc, b_scale_thread_copy,
            b_scale_grid_buf, b_scale_thread_slice_copy_step,
            num_k_block_main_loop, num_k_block_per_scale);

        // shuffle C and write out
        {
...
@@ -1756,7 +1750,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                       math::integer_divide_ceil(problem.K, ScaleBlockK)),
            make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));

        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());

        const AElementwiseOperation a_element_op{};
...
@@ -1867,7 +1861,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
        auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
                                 a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
...
@@ -1875,7 +1869,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
        auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
                                 a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
...
@@ -1924,28 +1918,29 @@ struct GridwiseGemm_xdl_cshuffle_v3
        const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
            a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_blockwise_copy, a_grid_buf,
            a_block_bufs, a_block_slice_copy_step,
            b_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_blockwise_copy, b_grid_buf,
            b_block_bufs, b_block_slice_copy_step,
            c_thread_buf,
            b_scale_grid_desc_bn_ak, b_scale_thread_desc, b_scale_thread_copy,
            b_scale_grid_buf, b_scale_thread_slice_copy_step,
            num_k_block_main_loop, num_k_block_per_scale);

        // shuffle C and write out
        {
...
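The descriptor change above turns the B-scale view into a plain 2-D tensor: one scale per (ScaleBlockN x ScaleBlockK) tile of B, with shape (ceil(N/ScaleBlockN), ceil(K/ScaleBlockK)) and strides (ceil(K/ScaleBlockK), 1). A hedged host-side sketch of the resulting addressing (the helper is hypothetical, written only to make the layout concrete):

    #include <cstdint>

    // Offset of the scale that applies to element (n, k) of B in the flattened
    // b-scale buffer, under the 2-D descriptor built in this commit.
    int64_t b_scale_offset(int64_t n, int64_t k,
                           int64_t K, int64_t ScaleBlockN, int64_t ScaleBlockK)
    {
        const int64_t k_blocks = (K + ScaleBlockK - 1) / ScaleBlockK; // ceil(K / ScaleBlockK)
        return (n / ScaleBlockN) * k_blocks + (k / ScaleBlockK);
    }

The main loop only reloads a scale every num_k_block_per_scale = ScaleBlockK / KPerBlock iterations (the code asserts KPerBlock <= ScaleBlockK), which is why the copy-step tuple now carries three steps: step 0 moves across N within a repeat, step 1 rewinds N while staying in the same K scale group, and step 2 rewinds N and advances to the next K scale group.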
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1176,7 +1176,7 @@ struct ThreadwiseTensorSliceTransfer_v4
            });
        }
        else if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
                          is_same<remove_cvref_t<DstData>, f8_t>::value)
        {
            // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
            // DstData)
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp

@@ -39,20 +39,20 @@ template <typename ADataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          index_t ScaleBlockK>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGemmV2BScale<ALayout,
                                                     BLayout,
                                                     CLayout,
                                                     ADataType,
                                                     BDataType,
                                                     BScaleDataType,
                                                     CDataType,
                                                     1,
                                                     ScaleBlockK,
                                                     ck::tensor_operation::element_wise::PassThrough,
                                                     ck::tensor_operation::element_wise::PassThrough,
                                                     ck::tensor_operation::element_wise::PassThrough>>
    (clang-format re-indentation of the specialization argument list only)
{
    using DeviceOp = DeviceGemmV2BScale<ALayout,
                                        BLayout,
...
@@ -70,7 +70,7 @@ struct DeviceOperationInstanceFactory<
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, pk_i4_t> &&
                     is_same_v<CDataType, half_t>)
        {
...
library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp

@@ -8,13 +8,22 @@ namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
                                                   Col,
                                                   Row,
                                                   F16,
                                                   I4,
                                                   F16,
                                                   F16,
                                                   1,
                                                   128,
                                                   PassThrough,
                                                   PassThrough,
                                                   PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        // device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Interwave, GemmDefault>{});
        device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Intrawave, GemmDefault>{});
}
...
profiler/include/profiler/profile_gemm_b_scale_impl.hpp

@@ -30,24 +30,24 @@ template <typename ADataType,
          typename ComputeDataType,
          typename AccDataType,
          typename CDataType,
          index_t ScaleBlockK,
          typename ALayout,
          typename BLayout,
          typename CLayout>
bool profile_gemm_b_scale_impl(int do_verification,
                               int init_method,
                               bool do_log,
                               bool time_kernel,
                               int M,
                               int N,
                               int K,
                               int StrideA,
                               int StrideB,
                               int StrideC,
                               int KBatch,
                               int n_warmup,
                               int n_iter,
                               uint64_t rotating = 0)
{
    bool pass = true;
...
@@ -66,24 +66,25 @@ bool profile_gemm_b_scale_impl(int do_verification,
    };

    ck::index_t Scale_Stride_BN =
        ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
            ? ((K + ScaleBlockK - 1) / ScaleBlockK)
            : N;

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
    Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
    Tensor<BScaleDataType> b1_k_n(
        f_host_tensor_descriptor((K + ScaleBlockK - 1) / ScaleBlockK, // K direction group size is ScaleBlockK
                                 N,                                   // N direction group size is 1
                                 Scale_Stride_BN,
                                 BLayout{}));
    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    int total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() +
                            b_k_n.GetElementSpaceSizeInBytes() +
                            b1_k_n.GetElementSpaceSizeInBytes();
    int rotating_count = std::max(
        1,
        std::min(n_iter,
                 static_cast<int>(std::ceil(static_cast<double>(rotating) / total_gemm_needed))));
...
@@ -167,9 +168,8 @@ bool profile_gemm_b_scale_impl(int do_verification,
                i4  = i4 - 8;
                v_b = ck::type_convert<float>(i4);

                b_k_n_dequant(k, n) = ck::type_convert<float>(v_b) *
                                      ck::type_convert<float>(b1_k_n(k / ScaleBlockK, n));
            }
        }

        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
...
@@ -291,21 +291,21 @@ bool profile_gemm_b_scale_impl(int do_verification,
        {
            auto kbatch_curr = kbatch_list[i];

            auto argument_ptr = op_ptr->MakeArgumentPointer(
                static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                M,
                N,
                K,
                StrideA,
                StrideB,
                StrideC,
                static_cast<BScaleDataType*>(b1_device_buf.GetDeviceBuffer()),
                kbatch_curr,
                a_element_op,
                b_element_op,
                c_element_op);

            auto invoker_ptr = op_ptr->MakeInvokerPointer();
...
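For verification the profiler dequantizes B on the host before running the reference GEMM: each int4 weight is shifted by the -8 zero point and multiplied by the scale of its (ScaleBlockK x 1) group, b1_k_n(k / ScaleBlockK, n). A hedged sketch of that reference step, with illustrative names and the simplifying assumption that the int4 values are already unpacked to one nibble per byte (the real tensors pack two per byte and use CK's Tensor type):

    #include <cstdint>
    #include <vector>

    std::vector<float> dequantize_b(const std::vector<uint8_t>& b_i4,   // K*N entries, one nibble each
                                    const std::vector<float>& b_scale,  // ceil(K/ScaleBlockK) x N, row-major
                                    int K, int N, int ScaleBlockK)
    {
        std::vector<float> b_dequant(static_cast<size_t>(K) * N);
        for(int k = 0; k < K; ++k)
            for(int n = 0; n < N; ++n)
            {
                // Shift the unsigned nibble to the signed range [-8, 7].
                const int i4  = static_cast<int>(b_i4[static_cast<size_t>(k) * N + n] & 0xF) - 8;
                // One scale per ScaleBlockK rows of the same column.
                const float s = b_scale[static_cast<size_t>(k / ScaleBlockK) * N + n];
                b_dequant[static_cast<size_t>(k) * N + n] = static_cast<float>(i4) * s;
            }
        return b_dequant;
    }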
profiler/src/profile_gemm_b_scale.cpp

@@ -32,8 +32,8 @@ enum struct GemmDataType
enum struct BScaleBlockTile
{
    K_64,  // 0
    K_128, // 1
};

#define OP_NAME "gemm_b_scale"
...
@@ -82,7 +82,14 @@ int profile_gemm_b_scale(int argc, char* argv[])
    const int StrideB = std::stoi(argv[13]);
    const int StrideC = std::stoi(argv[14]);
    const int KBatch  = std::stoi(argv[15]);

    printf("M:%d, N:%d, K:%d, StrideA:%d, StrideB:%d, StrideC:%d, KBatch:%d\n",
           M,
           N,
           K,
           StrideA,
           StrideB,
           StrideC,
           KBatch);

    int n_warmup = 1;
    int n_iter   = 10;
...
@@ -156,14 +163,18 @@ int profile_gemm_b_scale(int argc, char* argv[])
        return pass ? 0 : 1;
    };

    // if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN &&
    //    B_scale_block == BScaleBlockTile::K_64)
    // {
    //     return profile(F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<64>{}, Row{}, Col{},
    //                    Row{});
    // }
    if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN &&
       B_scale_block == BScaleBlockTile::K_128)
    {
        printf("F16_I4_F16 MK_NK_MN K_128\n");
        return profile(
            F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{});
    }
    else
    {
...