Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
efab74a3
Commit
efab74a3
authored
Jan 24, 2025
by
Rostyslav Geyyer
Browse files
Merge branch 'gfx950' into lwpck-2619
parents
86950b3a
bcef33c1
Changes
362
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4074 additions
and
58 deletions
+4074
-58
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+11
-6
example/ck_tile/16_batched_gemm/batched_gemm.hpp
example/ck_tile/16_batched_gemm/batched_gemm.hpp
+3
-6
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+35
-4
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
+13
-7
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
+21
-13
include/ck/ck.hpp
include/ck/ck.hpp
+4
-1
include/ck/config.h.in
include/ck/config.h.in
+4
-0
include/ck/library/utility/host_tensor.hpp
include/ck/library/utility/host_tensor.hpp
+66
-12
include/ck/library/utility/host_tensor_generator.hpp
include/ck/library/utility/host_tensor_generator.hpp
+30
-0
include/ck/tensor/static_tensor.hpp
include/ck/tensor/static_tensor.hpp
+2
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp
...block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp
+167
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp
...n/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp
+403
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp
...n/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp
+1248
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp
...n/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp
+530
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp
...n/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp
+686
-0
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
+41
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp
...n/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp
+5
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
...operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
+11
-2
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp
...n/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp
+781
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
...device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+13
-4
No files found.
example/ck_tile/16_batched_gemm/batched_gemm.cpp
View file @
efab74a3
...
...
@@ -16,7 +16,7 @@
#include "batched_gemm.hpp"
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
batched_gemm
(
const
b
atched
_g
emm
_ka
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
batched_gemm
(
const
ck_tile
::
B
atched
G
emm
HostA
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
...
...
@@ -70,20 +70,25 @@ float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config&
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeK
a
rgs
(
args
);
auto
kargs
=
Kernel
::
MakeK
ernelA
rgs
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
,
args
.
batch_count
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
!
Kernel
::
IsSupportedArgument
(
kargs
))
{
throw
std
::
runtime_error
(
"Wrong! Arguments not supported! Skipping gemm!
\n
"
);
}
if
(
s
.
log_level_
>
0
)
{
std
::
cout
<<
"Launching kernel with args:"
...
...
example/ck_tile/16_batched_gemm/batched_gemm.hpp
View file @
efab74a3
...
...
@@ -29,10 +29,6 @@ using BDataType = Types::BDataType;
using
AccDataType
=
Types
::
AccDataType
;
using
CDataType
=
Types
::
CDataType
;
struct
batched_gemm_kargs
:
public
ck_tile
::
BatchedGemmHostArgs
{
};
auto
create_args
(
int
argc
,
char
*
argv
[])
{
ck_tile
::
ArgParser
arg_parser
;
...
...
@@ -53,11 +49,12 @@ auto create_args(int argc, char* argv[])
.
insert
(
"prec"
,
"fp16"
,
"data type. fp16/bf16/fp8/bf8"
)
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
);
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
"split_k"
,
"1"
,
"splitK value"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
// host API
float
batched_gemm
(
b
atched
_g
emm
_ka
rgs
args
,
const
ck_tile
::
stream_config
&
s
);
float
batched_gemm
(
const
ck_tile
::
B
atched
G
emm
HostA
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
);
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
View file @
efab74a3
...
...
@@ -17,13 +17,15 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile
::
index_t
batch_stride_B
,
ck_tile
::
index_t
batch_stride_C
,
ck_tile
::
index_t
batch_count
,
ck_tile
::
index_t
kbatch
,
int
n_warmup
,
int
n_repeat
)
{
b
atched
_g
emm
_ka
rgs
args
;
ck_tile
::
B
atched
G
emm
HostA
rgs
args
;
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
k_batch
=
kbatch
;
args
.
M
=
M
;
args
.
N
=
N
;
args
.
K
=
K
;
...
...
@@ -79,6 +81,7 @@ int run_batched_gemm_example_with_layouts(int argc,
ck_tile
::
index_t
batch_stride_B
=
arg_parser
.
get_int
(
"batch_stride_b"
);
ck_tile
::
index_t
batch_stride_C
=
arg_parser
.
get_int
(
"batch_stride_c"
);
ck_tile
::
index_t
batch_count
=
arg_parser
.
get_int
(
"batch_count"
);
ck_tile
::
index_t
kbatch
=
arg_parser
.
get_int
(
"split_k"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
...
...
@@ -159,6 +162,7 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_B
,
batch_stride_C
,
batch_count
,
kbatch
,
n_warmup
,
n_repeat
);
...
...
@@ -188,15 +192,33 @@ int run_batched_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_A
,
batch_count
*
M
*
K
*
sizeof
(
ADataType
)));
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_B
,
batch_count
*
N
*
K
*
sizeof
(
BDataType
)));
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
)));
ck_tile
::
hip_check_error
(
hipMemcpy
(
d_A
,
a_m_k_dev_buf
.
GetDeviceBuffer
(),
batch_count
*
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
));
ck_tile
::
hip_check_error
(
hipMemcpy
(
d_B
,
b_k_n_dev_buf
.
GetDeviceBuffer
(),
batch_count
*
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
));
ck_tile
::
reference_batched_gemm_gpu
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_gpu_buf_ref
,
CLayout
>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
...
...
@@ -208,6 +230,15 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_C
,
batch_count
);
ck_tile
::
hip_check_error
(
hipMemcpy
(
c_m_n_gpu_buf_ref
.
GetDeviceBuffer
(),
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
));
ck_tile
::
hip_check_error
(
hipFree
(
d_A
));
ck_tile
::
hip_check_error
(
hipFree
(
d_B
));
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
...
...
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
View file @
efab74a3
...
...
@@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
auto
create_args
(
int
argc
,
char
*
argv
[])
{
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"a_layout"
,
"R"
,
"A tensor data layout - Row by default"
)
.
insert
(
"b_layout"
,
"R"
,
"B tensor data layout - Row by default"
)
.
insert
(
"c_layout"
,
"R"
,
"C tensor data layout - Row by default"
)
.
insert
(
"validate"
,
"1"
,
"0. No validation, 1. Validation on CPU"
)
.
insert
(
"warmup"
,
"10"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"group_count"
,
"16"
,
"group count"
);
arg_parser
.
insert
(
"Ms"
,
""
,
"M dimensions - empty by default."
)
.
insert
(
"Ns"
,
""
,
"N dimensions - empty by default."
)
.
insert
(
"Ks"
,
""
,
"K dimensions - empty by default."
)
.
insert
(
"stride_As"
,
""
,
"Tensor A strides - it is empty by default."
)
.
insert
(
"stride_Bs"
,
""
,
"Tensor B strides - it is empty by default."
)
.
insert
(
"stride_Cs"
,
""
,
"Tensor C strides - it is empty by default."
)
.
insert
(
"a_layout"
,
"R"
,
"A tensor data layout - Row by default."
)
.
insert
(
"b_layout"
,
"R"
,
"B tensor data layout - Row by default."
)
.
insert
(
"c_layout"
,
"R"
,
"C tensor data layout - Row by default."
)
.
insert
(
"validate"
,
"1"
,
"0. No validation, 1. Validation on CPU."
)
.
insert
(
"warmup"
,
"10"
,
"number of iterations before benchmark the kernel."
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel."
)
.
insert
(
"group_count"
,
"16"
,
"group count."
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
...
...
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
View file @
efab74a3
...
...
@@ -53,17 +53,24 @@ int run_grouped_gemm_example_with_layouts(int argc,
return
-
1
;
};
auto
valid_input_data
=
[
&
](
int
group_count
,
const
auto
&...
args
)
{
return
!
(
args
.
empty
()
||
...
)
&&
group_count
==
(
args
.
size
()
==
...
);
};
const
int
group_count
=
arg_parser
.
get_int
(
"group_count"
);
const
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
const
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
std
::
vector
<
ck_tile
::
index_t
>
Ms
;
std
::
vector
<
ck_tile
::
index_t
>
Ns
;
std
::
vector
<
ck_tile
::
index_t
>
Ks
;
std
::
vector
<
ck_tile
::
index_t
>
stride_As
;
std
::
vector
<
ck_tile
::
index_t
>
stride_Bs
;
std
::
vector
<
ck_tile
::
index_t
>
stride_Cs
;
std
::
vector
<
ck_tile
::
index_t
>
Ms
=
arg_parser
.
get_int_vec
(
"Ms"
)
;
std
::
vector
<
ck_tile
::
index_t
>
Ns
=
arg_parser
.
get_int_vec
(
"Ns"
)
;
std
::
vector
<
ck_tile
::
index_t
>
Ks
=
arg_parser
.
get_int_vec
(
"Ks"
)
;
std
::
vector
<
ck_tile
::
index_t
>
stride_As
=
arg_parser
.
get_int_vec
(
"stride_As"
)
;
std
::
vector
<
ck_tile
::
index_t
>
stride_Bs
=
arg_parser
.
get_int_vec
(
"stride_Bs"
)
;
std
::
vector
<
ck_tile
::
index_t
>
stride_Cs
=
arg_parser
.
get_int_vec
(
"stride_Cs"
)
;
if
(
!
valid_input_data
(
group_count
,
Ms
,
Ns
,
Ks
,
stride_As
,
stride_Bs
,
stride_Cs
))
{
std
::
cout
<<
"Please check the input data. Default values will be used."
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
Ms
.
push_back
(
256
+
256
*
i
);
...
...
@@ -74,6 +81,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
stride_Bs
.
push_back
(
Ks
[
i
]);
stride_Cs
.
push_back
(
Ns
[
i
]);
}
}
std
::
vector
<
ck_tile
::
HostTensor
<
ADataType
>>
a_m_k_tensors
;
std
::
vector
<
ck_tile
::
HostTensor
<
BDataType
>>
b_k_n_tensors
;
...
...
include/ck/ck.hpp
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -155,6 +155,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
// set rounding to nearest even as default for bf16 conversions
#define CK_USE_RNE_BF16_CONVERSION 1
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0
...
...
include/ck/config.h.in
View file @
efab74a3
...
...
@@ -97,6 +97,10 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif
#ifndef CK_ENABLE_DPP_KERNELS
#cmakedefine CK_ENABLE_DPP_KERNELS @CK_ENABLE_DPP_KERNELS@
#endif
//
// CK kernels which support XDL (MI series)
//
...
...
include/ck/library/utility/host_tensor.hpp
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -44,10 +44,19 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
else
os
<<
delim
;
if
constexpr
(
std
::
is_same_v
<
T
,
ck
::
f8_t
>
||
std
::
is_same_v
<
T
,
ck
::
bf8_t
>
)
using
RangeType
=
ck
::
remove_cvref_t
<
decltype
(
v
)
>
;
if
constexpr
(
std
::
is_same_v
<
RangeType
,
ck
::
f8_t
>
||
std
::
is_same_v
<
RangeType
,
ck
::
bf8_t
>
||
std
::
is_same_v
<
RangeType
,
ck
::
bhalf_t
>
)
{
os
<<
ck
::
type_convert
<
float
>
(
v
);
}
else
if
constexpr
(
std
::
is_same_v
<
RangeType
,
ck
::
pk_i4_t
>
)
{
const
auto
packed_floats
=
ck
::
type_convert
<
ck
::
float2_t
>
(
v
);
const
ck
::
vector_type
<
float
,
2
>
vector_of_floats
{
packed_floats
};
os
<<
vector_of_floats
.
template
AsType
<
float
>()[
ck
::
Number
<
0
>
{}]
<<
delim
<<
vector_of_floats
.
template
AsType
<
float
>()[
ck
::
Number
<
1
>
{}];
}
else
{
os
<<
static_cast
<
T
>
(
v
);
...
...
@@ -266,18 +275,18 @@ struct Tensor
using
Data
=
std
::
vector
<
T
>
;
template
<
typename
X
>
Tensor
(
std
::
initializer_list
<
X
>
lens
)
:
mDesc
(
lens
),
mData
(
mDesc
.
GetElementSpaceSize
())
Tensor
(
std
::
initializer_list
<
X
>
lens
)
:
mDesc
(
lens
),
mData
(
GetElementSpaceSize
())
{
}
template
<
typename
X
,
typename
Y
>
Tensor
(
std
::
initializer_list
<
X
>
lens
,
std
::
initializer_list
<
Y
>
strides
)
:
mDesc
(
lens
,
strides
),
mData
(
mDesc
.
GetElementSpaceSize
())
:
mDesc
(
lens
,
strides
),
mData
(
GetElementSpaceSize
())
{
}
template
<
typename
Lengths
>
Tensor
(
const
Lengths
&
lens
)
:
mDesc
(
lens
),
mData
(
mDesc
.
GetElementSpaceSize
())
Tensor
(
const
Lengths
&
lens
)
:
mDesc
(
lens
),
mData
(
GetElementSpaceSize
())
{
}
...
...
@@ -287,7 +296,7 @@ struct Tensor
{
}
Tensor
(
const
Descriptor
&
desc
)
:
mDesc
(
desc
),
mData
(
mDesc
.
GetElementSpaceSize
())
{}
Tensor
(
const
Descriptor
&
desc
)
:
mDesc
(
desc
),
mData
(
GetElementSpaceSize
())
{}
template
<
typename
OutT
>
Tensor
<
OutT
>
CopyAsType
()
const
...
...
@@ -322,7 +331,17 @@ struct Tensor
std
::
size_t
GetElementSize
()
const
{
return
mDesc
.
GetElementSize
();
}
std
::
size_t
GetElementSpaceSize
()
const
{
return
mDesc
.
GetElementSpaceSize
();
}
std
::
size_t
GetElementSpaceSize
()
const
{
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cvref_t
<
T
>
,
ck
::
pk_i4_t
>
)
{
return
(
mDesc
.
GetElementSpaceSize
()
+
1
)
/
2
;
}
else
{
return
mDesc
.
GetElementSpaceSize
();
}
}
std
::
size_t
GetElementSpaceSizeInBytes
()
const
{
return
sizeof
(
T
)
*
GetElementSpaceSize
();
}
...
...
@@ -468,31 +487,66 @@ struct Tensor
template
<
typename
...
Is
>
std
::
size_t
GetOffsetFromMultiIndex
(
Is
...
is
)
const
{
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cvref_t
<
T
>
,
ck
::
pk_i4_t
>
)
{
return
mDesc
.
GetOffsetFromMultiIndex
(
is
...)
/
2
;
}
else
{
return
mDesc
.
GetOffsetFromMultiIndex
(
is
...);
}
}
template
<
typename
...
Is
>
T
&
operator
()(
Is
...
is
)
{
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cvref_t
<
T
>
,
ck
::
pk_i4_t
>
)
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)
/
2
];
}
else
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)];
}
}
template
<
typename
...
Is
>
const
T
&
operator
()(
Is
...
is
)
const
{
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cvref_t
<
T
>
,
ck
::
pk_i4_t
>
)
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)
/
2
];
}
else
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)];
}
}
T
&
operator
()(
std
::
vector
<
std
::
size_t
>
idx
)
{
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cvref_t
<
T
>
,
ck
::
pk_i4_t
>
)
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)
/
2
];
}
else
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)];
}
}
const
T
&
operator
()(
std
::
vector
<
std
::
size_t
>
idx
)
const
{
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cvref_t
<
T
>
,
ck
::
pk_i4_t
>
)
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)
/
2
];
}
else
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)];
}
}
typename
Data
::
iterator
begin
()
{
return
mData
.
begin
();
}
...
...
include/ck/library/utility/host_tensor_generator.hpp
View file @
efab74a3
...
...
@@ -93,6 +93,20 @@ struct GeneratorTensor_1<int8_t>
}
};
template
<
>
struct
GeneratorTensor_1
<
ck
::
pk_i4_t
>
{
int8_t
value
=
1
;
template
<
typename
...
Is
>
ck
::
pk_i4_t
operator
()(
Is
...)
{
int
t
=
value
+
8
;
ck
::
pk_i4_t
r
=
((
t
<<
4
)
+
t
)
&
0xff
;
return
r
;
}
};
template
<
typename
T
>
struct
GeneratorTensor_2
{
...
...
@@ -133,6 +147,22 @@ struct GeneratorTensor_2<int8_t>
}
};
template
<
>
struct
GeneratorTensor_2
<
ck
::
pk_i4_t
>
{
int
min_value
=
0
;
int
max_value
=
1
;
template
<
typename
...
Is
>
ck
::
pk_i4_t
operator
()(
Is
...)
{
int
hi
=
std
::
rand
()
%
(
max_value
-
min_value
)
+
min_value
+
8
;
int
lo
=
std
::
rand
()
%
(
max_value
-
min_value
)
+
min_value
+
8
;
ck
::
pk_i4_t
r
=
((
hi
<<
4
)
+
lo
)
&
0xff
;
return
r
;
}
};
#if defined CK_ENABLE_FP8
template
<
>
struct
GeneratorTensor_2
<
ck
::
f8_t
>
...
...
include/ck/tensor/static_tensor.hpp
View file @
efab74a3
...
...
@@ -167,7 +167,7 @@ struct StaticTensorTupleOfVectorBuffer
// Idx is for S, not X. Idx should be aligned with X
template
<
typename
X
,
typename
Idx
,
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
&&
typename
enable_if
<
(
has_same_scalar_type
<
S
,
X
>
::
value
||
!
is_native_type
<
S
>
())
&&
is_known_at_compile_time
<
Idx
>::
value
&&
Idx
::
Size
()
==
ndim_
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
X
GetAsType
(
Idx
)
const
...
...
@@ -201,7 +201,7 @@ struct StaticTensorTupleOfVectorBuffer
// Idx is for S, not X. Idx should be aligned with X
template
<
typename
X
,
typename
Idx
,
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
&&
typename
enable_if
<
(
has_same_scalar_type
<
S
,
X
>
::
value
||
!
is_native_type
<
S
>
())
&&
is_known_at_compile_time
<
Idx
>::
value
&&
Idx
::
Size
()
==
ndim_
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
void
SetAsType
(
Idx
,
X
x
)
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp
0 → 100644
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp"
namespace
ck
{
enum
struct
BlockGemmPipelineVersion
{
v1
,
// Naive
v2
,
// Mem
v3
,
// Comp
v4
,
// Comp, double lds buffer
v5
,
// Comp, double global prefetch register buffer
};
template
<
BlockGemmPipelineVersion
BlkGemmPipelineVer
,
BlockGemmPipelineScheduler
BlkGemmPipeSche
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
>
constexpr
auto
BlockGemmPipeline_Selector
()
{
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
{
return
BlockwiseGemmXdlops_pipeline_v1_b_scale
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
{
return
BlockwiseGemmXdlops_pipeline_v2_b_scale
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
return
BlockwiseGemmXdlops_pipeline_v3_b_scale
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
return
BlockwiseGemmXdlops_pipeline_v4_b_scale
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v5
)
{
return
BlockwiseGemmXdlops_pipeline_v5
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
{
std
::
cerr
<<
"BlockGemmPipeline configuration is not available"
<<
std
::
endl
;
}
}
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp
0 → 100644
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace
ck
{
// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template
<
BlockGemmPipelineScheduler
BlkGemmPipelineVer
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPacks
>
struct
BlockwiseGemmXdlops_pipeline_v1_b_scale
{
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v1_b_scale
<
BlockGemmPipelineScheduler
::
Intrawave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
I0
;
using
Base
::
KRepeat
;
using
Base
::
xdlops_gemm
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
static
constexpr
index_t
PrefetchStages
=
1
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
ignore
=
num_loop
;
return
TailNumber
::
Full
;
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
,
// BScale Thread Copy
typename
BScaleGridBuffer
,
typename
BScaleGridDesc
,
typename
BScaleThreadDesc
,
typename
BScaleThreadTransfer
,
typename
BScaleThreadTransferStep
>
__device__
void
Run
(
// ABlockCopy
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
// BBlockCopy
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
// CThread
CThreadBuffer
&
c_thread_buf
,
// BScaleThreadCopy
const
BScaleGridDesc
&
b_scale_grid_desc
,
const
BScaleThreadDesc
&
b_scale_thread_desc
,
BScaleThreadTransfer
&
b_scale_thread_copy
,
const
BScaleGridBuffer
&
b_scale_grid_buf
,
const
BScaleThreadTransferStep
&
b_scale_thread_copy_step
,
// num_loop
index_t
num_loop
,
index_t
num_loop_per_scale
)
const
{
// assume kperblock = scaleblockk
ignore
=
num_loop_per_scale
;
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
auto
b_scale_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_scale_thread_desc
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
b_scale_thread_buf
);
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
0
>
{}));
});
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
1
>
{}));
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// Initialize C
c_thread_buf
.
Clear
();
auto
c_thread_buf_per_scale
=
remove_cvref_t
<
decltype
(
c_thread_buf
)
>
();
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
// -------------------------------------------------------------------------------------------
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
c_thread_buf_per_scale
.
Clear
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
t
));
c_thread_buf
(
Number
<
c_offset
>
{})
+=
c_thread_buf_per_scale
[
Number
<
t
>
{}]
*
type_convert
<
AccDataType
>
(
b_scale_thread_buf
[
n0
]);
});
});
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
b_scale_thread_buf
);
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
0
>
{}));
});
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
1
>
{}));
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
i
+=
1
;
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
c_thread_buf_per_scale
.
Clear
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
t
));
c_thread_buf
(
Number
<
c_offset
>
{})
+=
c_thread_buf_per_scale
[
Number
<
t
>
{}]
*
type_convert
<
AccDataType
>
(
b_scale_thread_buf
[
n0
]);
});
});
});
}
}
protected:
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_desc_
;
using
Base
::
c_thread_desc_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp
0 → 100644
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace
ck
{
// Maximum Global Memory throughput pipeline with >=32KB data in fly
// GlobalPrefetchStages: >=2
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template
<
BlockGemmPipelineScheduler
BlkGemmPipelineVer
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPacks
>
struct
BlockwiseGemmXdlops_pipeline_v2_b_scale
{
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v2_b_scale
<
BlockGemmPipelineScheduler
::
Intrawave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
I0
;
using
Base
::
KRepeat
;
using
Base
::
xdlops_gemm
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
static
constexpr
index_t
WgpPerCU
=
(
4
*
warpSize
/
BlockSize
)
>=
1
?
4
*
warpSize
/
BlockSize
:
1
;
static
constexpr
index_t
FullMemBandPrefetchStages
=
math
::
integer_divide_ceil
(
32768
/
WgpPerCU
,
(
MPerBlock
*
sizeof
(
ADataType
)
+
NPerBlock
*
sizeof
(
BDataType
))
*
KPerBlock
);
static
constexpr
index_t
PrefetchStages
=
FullMemBandPrefetchStages
>=
2
?
FullMemBandPrefetchStages
<=
8
?
FullMemBandPrefetchStages
:
8
:
2
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
PrefetchStages
;
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
__device__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
if
(
num_loop
%
PrefetchStages
==
1
)
{
return
TailNumber
::
One
;
}
else
if
(
num_loop
%
PrefetchStages
==
2
)
{
return
TailNumber
::
Two
;
}
else
if
(
num_loop
%
PrefetchStages
==
3
)
{
return
TailNumber
::
Three
;
}
else
if
(
num_loop
%
PrefetchStages
==
4
)
{
return
TailNumber
::
Four
;
}
else
if
(
num_loop
%
PrefetchStages
==
5
)
{
return
TailNumber
::
Five
;
}
else
if
(
num_loop
%
PrefetchStages
==
6
)
{
return
TailNumber
::
Six
;
}
else
if
(
num_loop
%
PrefetchStages
==
7
)
{
return
TailNumber
::
Seven
;
}
else
{
return
TailNumber
::
Full
;
}
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I0
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I0
);
// Global prefetch [2, PrefetchStages]
static_for
<
1
,
PrefetchStages
,
1
>
{}([
&
](
auto
iprefetch
)
{
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
iprefetch
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
iprefetch
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
});
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
static_for
<
0
,
PrefetchStages
,
1
>
{}([
&
](
auto
iprefetch
)
{
// -------------------------------------------------------------------------------------------
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
(
iprefetch
+
1
)
%
PrefetchStages
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
(
iprefetch
+
1
)
%
PrefetchStages
>
{});
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
iprefetch
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
iprefetch
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
});
i
+=
PrefetchStages
;
}
while
(
i
<
(
num_loop
-
PrefetchStages
));
}
// tail
auto
LoopTailFunc
=
[
&
](
auto
tail_num
)
{
static_for
<
1
,
tail_num
,
1
>
{}([
&
](
auto
iprefetch
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
iprefetch
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
iprefetch
);
});
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
};
if
constexpr
(
TailNum
==
TailNumber
::
One
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Two
)
{
LoopTailFunc
(
Number
<
2
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Three
)
{
LoopTailFunc
(
Number
<
3
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Four
)
{
LoopTailFunc
(
Number
<
4
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Five
)
{
LoopTailFunc
(
Number
<
5
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Six
)
{
LoopTailFunc
(
Number
<
6
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Seven
)
{
LoopTailFunc
(
Number
<
7
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
LoopTailFunc
(
Number
<
PrefetchStages
>
{});
}
}
protected:
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_desc_
;
using
Base
::
c_thread_desc_
;
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v2_b_scale
<
BlockGemmPipelineScheduler
::
Interwave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
A_K1
;
using
Base
::
B_K1
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
KPerThread
;
using
Base
::
xdlops_gemm
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
static
constexpr
index_t
NumMacClusters
=
CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
;
static
constexpr
index_t
KPerInnerLoop
=
math
::
max
(
KPerThread
/
NumMacClusters
,
KPack
);
static
constexpr
index_t
KRepeat
=
KPerThread
/
KPerInnerLoop
;
static
constexpr
index_t
WgpPerCU
=
(
4
*
warpSize
/
BlockSize
)
>=
1
?
4
*
warpSize
/
BlockSize
:
1
;
static
constexpr
index_t
FullMemBandPrefetchStages
=
math
::
integer_divide_ceil
(
32768
/
WgpPerCU
,
(
MPerBlock
*
sizeof
(
ADataType
)
+
NPerBlock
*
sizeof
(
BDataType
))
*
KPerBlock
);
static
constexpr
index_t
PrefetchStages
=
FullMemBandPrefetchStages
>=
2
?
FullMemBandPrefetchStages
<=
8
?
FullMemBandPrefetchStages
:
8
:
2
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
PrefetchStages
;
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
__device__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
if
(
num_loop
%
PrefetchStages
==
1
)
{
return
TailNumber
::
One
;
}
else
if
(
num_loop
%
PrefetchStages
==
2
)
{
return
TailNumber
::
Two
;
}
else
if
(
num_loop
%
PrefetchStages
==
3
)
{
return
TailNumber
::
Three
;
}
else
if
(
num_loop
%
PrefetchStages
==
4
)
{
return
TailNumber
::
Four
;
}
else
if
(
num_loop
%
PrefetchStages
==
5
)
{
return
TailNumber
::
Five
;
}
else
if
(
num_loop
%
PrefetchStages
==
6
)
{
return
TailNumber
::
Six
;
}
else
if
(
num_loop
%
PrefetchStages
==
7
)
{
return
TailNumber
::
Seven
;
}
else
{
return
TailNumber
::
Full
;
}
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
,
typename
BScaleGridBuffer
,
typename
BScaleGridDesc
,
typename
BScaleThreadDesc
,
typename
BScaleThreadTransfer
,
typename
BScaleThreadTransferStep
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
const
BScaleGridDesc
&
b_scale_grid_desc
,
// BScaleThreadCopy
const
BScaleThreadDesc
&
b_scale_thread_desc
,
BScaleThreadTransfer
&
b_scale_thread_copy
,
const
BScaleGridBuffer
&
b_scale_grid_buf
,
const
BScaleThreadTransferStep
&
b_scale_thread_copy_step
,
// num loop
index_t
num_loop
,
index_t
num_loop_per_scale
)
const
{
ignore
=
num_loop_per_scale
;
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
auto
b_scale_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_scale_thread_desc
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
b_scale_thread_buf
);
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
0
>
{}));
});
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
1
>
{}));
// Initialize C
c_thread_buf
.
Clear
();
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I0
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I0
);
// Global prefetch [2, PrefetchStages]
static_for
<
1
,
PrefetchStages
,
1
>
{}([
&
](
auto
iprefetch
)
{
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
iprefetch
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
iprefetch
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
});
auto
c_thread_buf_per_scale
=
remove_cvref_t
<
decltype
(
c_thread_buf
)
>
();
// need?
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
static_for
<
0
,
PrefetchStages
,
1
>
{}([
&
](
auto
iprefetch
)
{
// -------------------------------------------------------------------------------------------
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
// NOTE: Synchronize threads in a workgroup at the start of each MAC
// cluster, but except the first, as we can shorten non-MAC cluster a bit
// and there's no observable negative impact. The desired effect is waves in
// a workgroup executing MAC in sync. This avoids some out-of-sync waves
// hijacking MAC resource from other workgroups and reducing the chance of
// latency hiding by waiting for the rest of the workgroup at the eventual
// sync point.
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts It is
// performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t)
// {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// b_scale_thread_copy.Run(b_scale_grid_desc,
// b_scale_grid_buf,
// b_scale_thread_desc,
// make_tuple(n0, I0),
// b_scale_thread_buf);
// b_scale_thread_copy.MoveSrcSliceWindow(
// b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
// });
// b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
// b_scale_thread_copy_step.At(Number<1>{}));
// block_sync_lds();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
(
iprefetch
+
1
)
%
PrefetchStages
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
(
iprefetch
+
1
)
%
PrefetchStages
>
{});
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
iprefetch
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
iprefetch
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
});
i
+=
PrefetchStages
;
}
while
(
i
<
(
num_loop
-
PrefetchStages
));
}
// tail
auto
LoopTailFunc
=
[
&
](
auto
tail_num
)
{
static_for
<
1
,
tail_num
,
1
>
{}([
&
](
auto
iprefetch
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// b_scale_thread_copy.Run(b_scale_grid_desc,
// b_scale_grid_buf,
// b_scale_thread_desc,
// make_tuple(n0, I0),
// b_scale_thread_buf);
// b_scale_thread_copy.MoveSrcSliceWindow(
// b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
// });
// b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
// b_scale_thread_copy_step.At(Number<1>{}));
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
iprefetch
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
iprefetch
);
});
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
};
if
constexpr
(
TailNum
==
TailNumber
::
One
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Two
)
{
LoopTailFunc
(
Number
<
2
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Three
)
{
LoopTailFunc
(
Number
<
3
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Four
)
{
LoopTailFunc
(
Number
<
4
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Five
)
{
LoopTailFunc
(
Number
<
5
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Six
)
{
LoopTailFunc
(
Number
<
6
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Seven
)
{
LoopTailFunc
(
Number
<
7
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
LoopTailFunc
(
Number
<
PrefetchStages
>
{});
}
}
protected:
// K->M loopover
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPerInnerLoop
>
{}),
make_tuple
(
Number
<
KPerInnerLoop
>
{},
Number
<
KRepeat
*
MRepeat
*
KPerInnerLoop
>
{},
Number
<
MRepeat
*
KPerInnerLoop
>
{},
I1
));
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPerInnerLoop
>
{}),
make_tuple
(
Number
<
KPerInnerLoop
>
{},
Number
<
KRepeat
*
NRepeat
*
KPerInnerLoop
>
{},
Number
<
NRepeat
*
KPerInnerLoop
>
{},
I1
));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
ADataType
,
ComputeDataType
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
BDataType
,
ComputeDataType
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
{
Base
::
CalculateAThreadOriginDataIndex
()};
BThreadCopy
b_thread_copy_
{
Base
::
CalculateBThreadOriginDataIndex
()};
using
Base
::
c_thread_desc_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp
0 → 100644
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace
ck
{
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template
<
BlockGemmPipelineScheduler
BlkGemmPipelineVer
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPacks
>
struct
BlockwiseGemmXdlops_pipeline_v3_b_scale
{
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v3_b_scale
<
BlockGemmPipelineScheduler
::
Intrawave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
KRepeat
;
using
Base
::
xdlops_gemm
;
using
typename
Base
::
HotLoopInstList
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
static
constexpr
index_t
PrefetchStages
=
2
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
__device__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
ignore
=
num_loop
;
return
TailNumber
::
Full
;
}
__device__
static
constexpr
auto
HotLoopScheduler
()
{
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr
auto
num_ds_read_inst_a
=
HotLoopInstList
::
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
HotLoopInstList
::
A_LDS_Read_Inst_Num
:
HotLoopInstList
::
A_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_read_inst_b
=
HotLoopInstList
::
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
HotLoopInstList
::
B_LDS_Read_Inst_Num
:
HotLoopInstList
::
B_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_write_inst_a
=
HotLoopInstList
::
A_LDS_Write_Inst_Num
;
constexpr
auto
num_ds_write_inst_b
=
HotLoopInstList
::
B_LDS_Write_Inst_Num
;
constexpr
auto
num_buffer_load_inst_a
=
HotLoopInstList
::
A_Buffer_Load_Inst_Num
;
constexpr
auto
num_buffer_load_inst_b
=
HotLoopInstList
::
B_Buffer_Load_Inst_Num
;
constexpr
auto
num_mfma_inst
=
HotLoopInstList
::
C_MFMA_Inst_Num
;
constexpr
auto
mfma_cycle
=
NPerXDL
==
16
?
16
:
32
;
constexpr
auto
ds_read_a_issue_cycle
=
HotLoopInstList
::
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
8
:
4
;
constexpr
auto
ds_read_b_issue_cycle
=
HotLoopInstList
::
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
8
:
4
;
constexpr
auto
ds_read_a_mfma_rate
=
(
mfma_cycle
-
4
+
2
*
ds_read_a_issue_cycle
-
1
)
/
(
2
*
ds_read_a_issue_cycle
);
constexpr
auto
ds_read_b_mfma_rate
=
(
mfma_cycle
-
4
+
2
*
ds_read_b_issue_cycle
-
1
)
/
(
2
*
ds_read_b_issue_cycle
);
constexpr
auto
num_dsread_a_mfma
=
(
num_ds_read_inst_a
+
ds_read_a_mfma_rate
-
1
)
/
ds_read_a_mfma_rate
;
constexpr
auto
num_dsread_b_mfma
=
(
num_ds_read_inst_b
+
ds_read_b_mfma_rate
-
1
)
/
ds_read_b_mfma_rate
;
// stage 1
// Separate this part?
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) / sizeof(BDataType)
// ? sizeof(ComputeDataType) / sizeof(ADataType)
// : sizeof(ComputeDataType) / sizeof(BDataType);
constexpr
auto
num_mfma_stage1
=
num_mfma_inst
-
(
num_dsread_a_mfma
+
num_dsread_b_mfma
);
constexpr
auto
num_mfma_per_issue
=
num_mfma_stage1
/
(
num_buffer_load_inst_a
+
num_buffer_load_inst_b
);
constexpr
auto
num_dswrite_per_issue_a
=
num_ds_write_inst_a
/
num_buffer_load_inst_a
;
constexpr
auto
num_dswrite_per_issue_b
=
num_ds_write_inst_b
/
num_buffer_load_inst_b
;
static_for
<
0
,
num_buffer_load_inst_a
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
static_for
<
0
,
num_dswrite_per_issue_a
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dswrite_per_issue_a
,
0
);
// MFMA
});
static_for
<
0
,
num_buffer_load_inst_b
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
static_for
<
0
,
num_dswrite_per_issue_b
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dswrite_per_issue_b
,
0
);
// MFMA
});
// stage 2
static_for
<
0
,
num_dsread_a_mfma
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
((
num_ds_read_inst_a
-
(
i
+
1
)
*
ds_read_a_mfma_rate
)
>=
ds_read_a_mfma_rate
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_a_mfma_rate
,
0
);
// DS read
}
else
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst_a
-
(
num_dsread_a_mfma
-
1
)
*
ds_read_a_mfma_rate
,
0
);
// DS read
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
static_for
<
0
,
num_dsread_b_mfma
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
((
num_ds_read_inst_b
-
(
i
+
1
)
*
ds_read_b_mfma_rate
)
>=
ds_read_b_mfma_rate
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_b_mfma_rate
,
0
);
// DS read
}
else
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst_b
-
(
num_dsread_b_mfma
-
1
)
*
ds_read_b_mfma_rate
,
0
);
// DS read
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
,
typename
BScaleGridBuffer
,
typename
BScaleGridDesc
,
typename
BScaleThreadDesc
,
typename
BScaleThreadTransfer
,
typename
BScaleThreadTransferStep
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
// BScaleThreadCopy
const
BScaleGridDesc
&
b_scale_grid_desc
,
const
BScaleThreadDesc
&
b_scale_thread_desc
,
BScaleThreadTransfer
&
b_scale_thread_copy
,
const
BScaleGridBuffer
&
b_scale_grid_buf
,
const
BScaleThreadTransferStep
&
b_scale_thread_copy_step
,
// num loop
index_t
num_loop
,
index_t
num_loop_per_scale
)
const
{
__builtin_amdgcn_sched_barrier
(
0
);
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
// B scale buffer
auto
b_scale_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_scale_thread_desc
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
b_scale_thread_buf
);
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
0
>
{}));
});
if
(
num_loop_per_scale
==
1
)
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
2
>
{}));
}
else
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
1
>
{}));
}
constexpr
auto
num_scale_k_block
=
BScaleThreadDesc
{}.
GetLength
(
I1
);
constexpr
auto
num_scale_krepeat
=
KRepeat
/
num_scale_k_block
;
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// Global prefetch 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// Local prefetch 1
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_block_buf
,
b_scale_thread_buf
[
Number
<
n0
*
num_scale_k_block
+
k0
/
num_scale_krepeat
>
{}],
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
b_scale_thread_buf
);
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
0
>
{}));
});
if
((
i
+
2
)
%
num_loop_per_scale
==
0
)
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
2
>
{}));
}
else
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
1
>
{}));
}
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_block_buf
,
b_scale_thread_buf
[
Number
<
n0
*
num_scale_k_block
+
k0
/
num_scale_krepeat
>
{}],
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
i
+=
1
;
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
}
}
protected:
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_desc_
;
using
Base
::
c_thread_desc_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp
0 → 100644
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace
ck
{
// Compute optimimal pipeline with highest resource request
// GlobalPrefetchStages: 4
// LocalPreFillStages: 2
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 2
template
<
BlockGemmPipelineScheduler
BlkGemmPipelineVer
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPacks
>
struct
BlockwiseGemmXdlops_pipeline_v4_b_scale
{
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v4_b_scale
<
BlockGemmPipelineScheduler
::
Intrawave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
KRepeat
;
using
Base
::
xdlops_gemm
;
using
typename
Base
::
HotLoopInstList
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
static
constexpr
index_t
PrefetchStages
=
3
;
static
constexpr
index_t
PrefillStages
=
2
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
static
constexpr
index_t
HotloopUnroll
=
2
;
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
__device__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
if
(
num_loop
%
HotloopUnroll
==
1
)
{
return
TailNumber
::
Odd
;
}
else
{
return
TailNumber
::
Even
;
}
}
__device__
static
constexpr
void
HotLoopScheduler
()
{
// TODO: Take data type into consideration as pipe ver 3
// A-B splited schedule
constexpr
auto
num_ds_read_inst_a
=
HotLoopInstList
::
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
HotLoopInstList
::
A_LDS_Read_Inst_Num
:
HotLoopInstList
::
A_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_read_inst_b
=
HotLoopInstList
::
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
HotLoopInstList
::
B_LDS_Read_Inst_Num
:
HotLoopInstList
::
B_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_issue_a
=
HotLoopInstList
::
A_Buffer_Load_Inst_Num
;
constexpr
auto
num_dswrite_per_issue_a
=
(
HotLoopInstList
::
A_LDS_Write_Inst_Num
+
num_issue_a
-
1
)
/
num_issue_a
;
constexpr
auto
num_dsread_per_issue_a
=
num_ds_read_inst_a
/
num_issue_a
;
constexpr
auto
num_issue_b
=
HotLoopInstList
::
B_Buffer_Load_Inst_Num
;
constexpr
auto
num_dswrite_per_issue_b
=
(
HotLoopInstList
::
B_LDS_Write_Inst_Num
+
num_issue_b
-
1
)
/
num_issue_b
;
constexpr
auto
num_dsread_per_issue_b
=
num_ds_read_inst_b
/
num_issue_b
;
constexpr
auto
num_mfma_per_issue
=
HotLoopInstList
::
C_MFMA_Inst_Num
/
(
num_issue_a
+
num_issue_b
);
static_for
<
0
,
num_issue_a
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
static_for
<
0
,
num_dsread_per_issue_a
,
1
>
{}([
&
](
auto
idsread
)
{
ignore
=
idsread
;
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
static_for
<
0
,
num_dswrite_per_issue_a
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dsread_per_issue_a
-
num_dswrite_per_issue_a
,
0
);
// MFMA
});
static_for
<
0
,
num_issue_b
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
static_for
<
0
,
num_dsread_per_issue_b
,
1
>
{}([
&
](
auto
idsread
)
{
ignore
=
idsread
;
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
static_for
<
0
,
num_dswrite_per_issue_b
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dsread_per_issue_a
-
num_dswrite_per_issue_b
,
0
);
// MFMA
});
__builtin_amdgcn_sched_barrier
(
0
);
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
,
typename
BScaleGridBuffer
,
typename
BScaleGridDesc
,
typename
BScaleThreadDesc
,
typename
BScaleThreadTransfer
,
typename
BScaleThreadTransferStep
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
// BScaleThreadCopy
const
BScaleGridDesc
&
b_scale_grid_desc
,
const
BScaleThreadDesc
&
b_scale_thread_desc
,
BScaleThreadTransfer
&
b_scale_thread_copy
,
const
BScaleGridBuffer
&
b_scale_grid_buf
,
const
BScaleThreadTransferStep
&
b_scale_thread_copy_step
,
// num loop
index_t
num_loop
,
index_t
num_loop_per_scale
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
// B scale buffer
auto
b_scale_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_scale_thread_desc
.
GetElementSpaceSize
());
StaticallyIndexedArray
<
decltype
(
a_thread_buf
),
Number
<
2
>
{}
>
a_thread_bufs
;
StaticallyIndexedArray
<
decltype
(
b_thread_buf
),
Number
<
2
>
{}
>
b_thread_bufs
;
StaticallyIndexedArray
<
decltype
(
b_scale_thread_buf
),
Number
<
2
>
{}
>
b_scale_thread_bufs
;
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
b_scale_thread_bufs
(
I0
));
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
0
>
{}));
});
if
(
num_loop_per_scale
==
1
)
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
2
>
{}));
}
else
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
1
>
{}));
}
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I0
));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
I0
));
// Global prefetch 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
b_scale_thread_bufs
(
I1
));
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
0
>
{}));
});
if
(
2
%
num_loop_per_scale
==
0
)
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
2
>
{}));
}
else
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
1
>
{}));
}
// Local prefetch 1
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
I0
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
I0
));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
I0
),
b_scale_thread_bufs
(
I0
)[
n0
],
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
I0
));
});
});
});
// Local prefill 2
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I1
));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
I1
));
// Global prefetch 3
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
b_scale_thread_bufs
(
I0
));
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
0
>
{}));
});
if
(
3
%
num_loop_per_scale
==
0
)
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
2
>
{}));
}
else
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
1
>
{}));
}
// Initialize C
c_thread_buf
.
Clear
();
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
// This hot loop has two legacy loopover, to implement the double local buffer strategy
do
{
auto
LoopFunc
=
[
&
](
auto
lds_read_buf
,
auto
lds_read_reg_buf
,
auto
lds_write_buf
,
auto
mfma_reg_buf
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
lds_read_buf
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
lds_read_reg_buf
));
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
lds_read_buf
),
b_scale_thread_bufs
(
lds_read_buf
)[
n0
],
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
lds_read_reg_buf
));
});
});
// B scale copy
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_scale_thread_copy
.
Run
(
b_scale_grid_desc
,
b_scale_grid_buf
,
b_scale_thread_desc
,
make_tuple
(
n0
,
I0
),
b_scale_thread_bufs
(
lds_read_reg_buf
));
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
0
>
{}));
});
if
((
i
+
4
+
mfma_reg_buf
.
value
)
%
num_loop_per_scale
==
0
)
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
2
>
{}));
}
else
{
b_scale_thread_copy
.
MoveSrcSliceWindow
(
b_scale_grid_desc
,
b_scale_thread_copy_step
.
At
(
Number
<
1
>
{}));
}
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
lds_write_buf
));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
lds_write_buf
));
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_bufs
[
mfma_reg_buf
]
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
mfma_reg_buf
]
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
HotLoopScheduler
();
};
LoopFunc
(
I1
,
I1
,
I0
,
I0
);
LoopFunc
(
I0
,
I0
,
I1
,
I1
);
i
+=
HotloopUnroll
;
}
while
(
i
<
(
num_loop
-
PrefetchStages
));
}
auto
ReadWriteCompFunc
=
[
&
](
auto
lds_read_buf
,
auto
lds_read_reg_buf
,
auto
lds_write_buf
,
auto
mfma_reg_buf
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
lds_read_buf
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
lds_read_reg_buf
));
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
lds_read_buf
),
b_scale_thread_bufs
(
lds_read_buf
)[
n0
],
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
lds_read_reg_buf
));
});
});
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
lds_write_buf
));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
lds_write_buf
));
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_bufs
[
mfma_reg_buf
][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
mfma_reg_buf
][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
HotLoopScheduler
();
};
auto
ReadCompFunc
=
[
&
](
auto
lds_read_buf
,
auto
lds_read_reg_buf
,
auto
mfma_reg_buf
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
lds_read_buf
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
lds_read_reg_buf
));
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
lds_read_buf
),
b_scale_thread_bufs
(
lds_read_buf
)[
n0
],
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
lds_read_reg_buf
));
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_bufs
[
mfma_reg_buf
][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
mfma_reg_buf
][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
HotLoopScheduler
();
};
auto
CompFunc
=
[
&
](
auto
mfma_reg_buf
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_bufs
[
mfma_reg_buf
][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
mfma_reg_buf
][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
};
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Odd
)
{
ReadWriteCompFunc
(
I1
,
I1
,
I0
,
I0
);
ReadCompFunc
(
I0
,
I0
,
I1
);
CompFunc
(
I0
);
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Even
)
{
ReadCompFunc
(
I1
,
I1
,
I0
);
CompFunc
(
I1
);
}
}
protected:
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_desc_
;
using
Base
::
c_thread_desc_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
View file @
efab74a3
...
...
@@ -36,6 +36,10 @@ struct DeviceGemmV2 : public BaseOperator
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
bool
GetPermuteA
()
=
0
;
virtual
bool
GetPermuteB
()
=
0
;
virtual
ck
::
index_t
GetKPerBlock
()
=
0
;
};
template
<
typename
ALayout
,
...
...
@@ -73,6 +77,43 @@ struct DeviceGemmV2R1 : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
BScaleType
,
typename
CDataType
,
index_t
ScaleBlockN
,
index_t
ScaleBlockK
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGemmV2BScale
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
ck
::
index_t
StrideScaleB
,
const
void
*
p_b_scale
,
ck
::
index_t
KSplit
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
bool
GetPermuteB
()
=
0
;
virtual
ck
::
index_t
GetKPerBlock
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp
100755 → 100644
View file @
efab74a3
...
...
@@ -469,7 +469,11 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{
return
false
;
}
if
(
!
is_bf16_atomic_supported
()
&&
std
::
is_same_v
<
CDataType
,
ck
::
bhalf_t
>
&&
arg
.
Streamk_sel
>
0
)
{
return
false
;
}
if
((
arg
.
K
%
AK1
!=
0
||
arg
.
K
%
BK1
!=
0
)
&&
!
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
||
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
View file @
efab74a3
...
...
@@ -64,7 +64,9 @@ template <typename ALayout,
BlockGemmPipelineScheduler
BlkGemmPipeSched
=
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
BlkGemmPipelineVer
=
BlockGemmPipelineVersion
::
v1
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
typename
ComputeTypeB
=
ComputeTypeA
,
bool
PermuteA
=
false
,
bool
PermuteB
=
false
>
struct
DeviceGemm_Xdl_CShuffleV3
:
public
DeviceGemmV2
<
ALayout
,
BLayout
,
CLayout
,
...
...
@@ -122,7 +124,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
BlkGemmPipeSched
,
BlkGemmPipelineVer
,
ComputeTypeA
,
ComputeTypeB
>
;
ComputeTypeB
,
PermuteA
,
PermuteB
>
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
...
...
@@ -633,6 +637,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
index_t
GetKPerBlock
()
override
{
return
KPerBlock
;
}
bool
GetPermuteA
()
override
{
return
PermuteA
;
}
bool
GetPermuteB
()
override
{
return
PermuteB
;
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp
0 → 100644
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
BScaleDataType
,
typename
CDataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
GemmSpec
,
index_t
BlockSize
,
index_t
ScaleBlockN
,
// scale block for N
index_t
ScaleBlockK
,
// scale block for K
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
=
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
BlkGemmPipelineVer
=
BlockGemmPipelineVersion
::
v1
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
,
bool
PermuteA
=
false
,
bool
PermuteB
=
false
>
struct
DeviceGemm_Xdl_CShuffleV3
:
public
DeviceGemmV2BScale
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
BScaleDataType
,
CDataType
,
ScaleBlockN
,
ScaleBlockK
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_xdl_cshuffle_v3
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
GemmAccDataType
,
CShuffleDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
GemmSpec
,
BlockSize
,
ScaleBlockN
,
ScaleBlockK
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
BlkGemmPipeSched
,
BlkGemmPipelineVer
,
ComputeTypeA
,
ComputeTypeB
,
PermuteA
,
PermuteB
>
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
// Invoker
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
arg
.
Print
();
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
M
,
arg
.
N
,
arg
.
KBatch
);
float
ave_time
=
0
;
index_t
k_grain
=
arg
.
KBatch
*
KPerBlock
;
index_t
K_split
=
(
arg
.
K
+
k_grain
-
1
)
/
k_grain
*
KPerBlock
;
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K_split
);
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
if
(
stream_config
.
flush_cache
)
{
Argument
arg_
=
arg
;
const
auto
a_grid_desc_ak0_m_ak1
=
GridwiseGemm
::
MakeAGridDescriptor_AK0_M_AK1
(
arg_
.
M
,
arg_
.
MPadded
,
arg_
.
K
,
arg_
.
KPadded
,
arg_
.
StrideA
,
arg_
.
AK0
);
const
auto
b_grid_desc_bk0_n_bk1
=
GridwiseGemm
::
MakeBGridDescriptor_BK0_N_BK1
(
arg_
.
K
,
arg_
.
KPadded
,
arg_
.
N
,
arg_
.
NPadded
,
arg_
.
StrideB
,
arg_
.
BK0
);
auto
size_a_buffer
=
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
()
*
sizeof
(
ADataType
);
auto
size_b_buffer
=
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
()
*
sizeof
(
BDataType
);
ck
::
utility
::
RotatingMemWrapper
<
Argument
>
rotating_mem
(
arg_
,
stream_config
.
rotating_count
,
size_a_buffer
,
size_b_buffer
);
rotating_mem
.
Print
();
auto
run_flush_cache
=
[
&
]()
{
// flush icache
ck
::
utility
::
flush_icache
();
// rotating mem
rotating_mem
.
Next
();
// clear c mem
if
(
arg_
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg_
.
p_c_grid
,
0
,
arg_
.
M
*
arg_
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
};
ave_time
=
ck
::
utility
::
launch_and_time_kernel_with_preprocess
<
false
>
(
stream_config
,
run_flush_cache
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg_
);
}
else
{
if
(
arg
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg
.
p_c_grid
,
0
,
arg
.
M
*
arg
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg
);
}
};
constexpr
index_t
minimum_occupancy
=
BlkGemmPipeSched
==
BlockGemmPipelineScheduler
::
Intrawave
?
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
&&
MPerBlock
*
NPerBlock
*
KPerBlock
*
sizeof
(
ADataType
)
<=
128
*
128
*
64
*
2
)
?
2
:
1
:
2
;
if
(
has_main_k_block_loop
)
{
// Tail number always full
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
||
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
if
(
arg
.
KBatch
>
1
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
Run
(
kernel
);
}
}
// Tail number could be One to Seven
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
{
if
(
arg
.
KBatch
>
1
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
One
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
One
>
;
Run
(
kernel
);
}
else
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Full
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Full
>
;
Run
(
kernel
);
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Two
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
3
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Three
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Three
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
4
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Four
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Four
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
5
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Five
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Five
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Six
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Six
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
7
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Seven
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Seven
>
;
Run
(
kernel
);
}
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
One
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
One
>
;
Run
(
kernel
);
}
else
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Full
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Full
>
;
Run
(
kernel
);
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Two
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
3
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Three
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Three
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
4
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Four
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Four
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
5
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Five
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Five
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Six
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Six
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
7
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Seven
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Seven
>
;
Run
(
kernel
);
}
}
}
}
// Tail number could be Odd or Even
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
if
(
arg
.
KBatch
>
1
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
else
{
if
(
arg
.
KBatch
>
1
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
}
else
{
// Tail number always 1
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
{
if
(
arg
.
KBatch
>
1
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
Run
(
kernel
);
}
}
}
return
ave_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
if
(
!
is_bf16_atomic_supported
()
&&
std
::
is_same_v
<
CDataType
,
ck
::
bhalf_t
>
&&
arg
.
KBatch
>
1
)
{
return
false
;
}
if
((
arg
.
K
%
AK1
!=
0
||
arg
.
K
%
BK1
!=
0
)
&&
!
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
||
GemmSpec
==
GemmSpecialization
::
KPadding
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
index_t
GetKPerBlock
()
override
{
return
KPerBlock
;
}
bool
GetPermuteB
()
override
{
return
PermuteB
;
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideScaleB
,
const
BScaleDataType
*
p_b_scale
,
index_t
KBatch
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
StrideScaleB
,
p_b_scale
,
KBatch
,
a_element_op
,
b_element_op
,
c_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideScaleB
,
const
void
*
p_b_scale
,
index_t
KBatch
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
StrideScaleB
,
static_cast
<
const
BScaleDataType
*>
(
p_b_scale
),
KBatch
,
a_element_op
,
b_element_op
,
c_element_op
);
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
BlockGemmPipelineScheduler
,
std
::
string
>
BlkGemmPipelineSchedulerToString
{
{
BlockGemmPipelineScheduler
::
Intrawave
,
"Intrawave"
},
{
BlockGemmPipelineScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
BlockGemmPipelineVersion
,
std
::
string
>
BlkGemmPipelineVersionToString
{
{
BlockGemmPipelineVersion
::
v1
,
"v1"
},
{
BlockGemmPipelineVersion
::
v2
,
"v2"
},
{
BlockGemmPipelineVersion
::
v3
,
"v3"
},
{
BlockGemmPipelineVersion
::
v4
,
"v4"
},
{
BlockGemmPipelineVersion
::
v5
,
"v5"
}};
// clang-format off
str
<<
"DeviceGemmXdlUniversal"
<<
"<"
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
std
::
string
(
ALayout
::
name
)[
0
]
<<
std
::
string
(
BLayout
::
name
)[
0
]
<<
std
::
string
(
CLayout
::
name
)[
0
]
<<
">"
<<
" BlkSize: "
<<
BlockSize
<<
", "
<<
"BlkTile: "
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerBlock
<<
", "
<<
"WaveTile: "
<<
MPerXDL
<<
"x"
<<
NPerXDL
<<
", "
<<
"WaveMap: "
<<
MXdlPerWave
<<
"x"
<<
NXdlPerWave
<<
", "
<<
"VmemReadVec: "
<<
ABlockTransferSrcScalarPerVector
<<
"x"
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
"BlkGemmPipelineScheduler: "
<<
BlkGemmPipelineSchedulerToString
[
BlkGemmPipeSched
]
<<
", "
<<
"BlkGemmPipelineVersion: "
<<
BlkGemmPipelineVersionToString
[
BlkGemmPipelineVer
]
<<
", "
<<
"BlkGemmPipelinePrefetchStages: "
<<
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -1557,14 +1557,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
}
}
if
(
!
(
arg
.
Conv_C_
%
BBlockTransferSrcScalarPerVector
==
0
&&
const
bool
is_w_pad_zero
=
arg
.
input_left_pads_
[
NDimSpatial
-
1
]
==
0
&&
arg
.
input_right_pads_
[
NDimSpatial
-
1
]
==
0
;
const
auto
X
=
arg
.
filter_spatial_lengths_
[
NDimSpatial
-
1
];
const
bool
XC_access_allowed
=
arg
.
Conv_G_
==
1
&&
(
arg
.
Conv_C_
*
X
)
%
BBlockTransferSrcScalarPerVector
==
0
&&
is_w_pad_zero
;
if
(
!
((
arg
.
Conv_C_
%
BBlockTransferSrcScalarPerVector
==
0
||
XC_access_allowed
)
&&
arg
.
Conv_K_
%
ABlockTransferSrcScalarPerVector
==
0
))
{
if
(
!
(
arg
.
Conv_K_
==
1
&&
arg
.
compute_ptr_offset_of_batch_
.
BatchStrideA_
==
1
))
if
(
!
(
arg
.
Conv_K_
==
1
&&
arg
.
compute_ptr_offset_of_batch_
.
BatchStrideA_
==
1
&&
NumGroupsToMerge
>
1
))
{
return
false
;
}
if
(
!
(
arg
.
Conv_C_
==
1
&&
arg
.
compute_ptr_offset_of_batch_
.
BatchStrideB_
==
1
))
if
(
!
(
arg
.
Conv_C_
==
1
&&
arg
.
compute_ptr_offset_of_batch_
.
BatchStrideB_
==
1
&&
NumGroupsToMerge
>
1
))
{
return
false
;
}
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment