gaoqiong / composable_kernel · Commits · 300ac4e4

Commit 300ac4e4, authored Jun 02, 2022 by rocking
Parent: 09a2b547

Implement gemm bias add reduction

Showing 6 changed files with 703 additions and 77 deletions (+703, -77)
  example/21_gemm_layernorm/CMakeLists.txt                                                +1    -0
  example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp                     +415  -0
  include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp     +73   -31
  include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp                           +46   -0
  include/ck/tensor_operation/gpu/element/element_wise_operation.hpp                      +15   -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp  +153  -46
example/21_gemm_layernorm/CMakeLists.txt

+add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp)
 add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp)
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp  (new file, mode 0 → 100644)

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_5ary_elementwise.hpp"
#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using ADataType                = F16;
using BDataType                = F16;
using CDataType                = F16;
using C0DataType               = F32;
using C1DataType               = F32;
using GemmAccDataType          = F32;
using ReduceAccDataType        = F32;
using DDataType                = F32;
using DPtrsGlobal              = ck::Tuple<DDataType*, DDataType*>;
using GammaDataType            = F16;
using BetaDataType             = F16;
using LayerNormOutDataType     = F16;
using NormalizeComputeDataType = F32;

using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp  = PassThrough;
using BElementOp  = PassThrough;
using CElementOp  = ck::tensor_operation::element_wise::Relu;

using ReduceSumOp = ck::reduce::Add<ReduceAccDataType>;
using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>;

using UnaryIdenticElementOp =
    ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>;
using UnaryDivElementOp =
    ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>;
using UnarySquareElementOp =
    ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
using DxsInElementOp  = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOp = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;

using DxsGlobalMemOp =
    ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                          ck::InMemoryDataOperationEnum::AtomicAdd>;

static constexpr auto GemmSpecialization =
    ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        < Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                        BDataType,
                                                                        CDataType,
                                                                        GemmAccDataType,
                                                                        AElementOp,
                                                                        BElementOp,
                                                                        PassThrough>;

using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;

// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
using DeviceNormalizeInstance =
    ck::tensor_operation::device::Device5AryElementwise<CDataType,
                                                        DDataType,
                                                        DDataType,
                                                        GammaDataType,
                                                        BetaDataType,
                                                        LayerNormOutDataType,
                                                        NormalizeComputeDataType,
                                                        NormalizeFunctor,
                                                        2,
                                                        8,
                                                        8,  // scalarPerVector: gemm_out
                                                        1,  // scalarPerVector: reduce_mean
                                                        1,  // scalarPerVector: reduce_mean_square
                                                        8,  // scalarPerVector: Gamma
                                                        8,  // scalarPerVector: Beta
                                                        8>; // scalarPerVector: LayerNorm_out
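
For reference, the five inputs of the Device5AryElementwise instance feed the usual layer-norm identity built from the two reduced moments named in the comment above. A sketch of the per-element math the Normalize functor is expected to apply (epsilon is the functor's own parameter, shown further down with a default of 1e-4); this is a reading aid, not code from the commit:

    \mathrm{Var}[x]_m = E[x^2]_m - \left(E[x]_m\right)^2, \qquad
    y_{m,n} = \gamma_n \, \frac{x_{m,n} - E[x]_m}{\sqrt{\mathrm{Var}[x]_m + \varepsilon}} + \beta_n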

auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
    return HostTensorDescriptor(std::vector<std::size_t>({len}),
                                std::vector<std::size_t>({stride}));
};

auto f_host_tensor_descriptor2d =
    [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
        if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                        std::vector<std::size_t>({stride, 1}));
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                        std::vector<std::size_t>({1, stride}));
        }
    };

template <typename CDataType,
          typename DDataType,
          typename AccDataType,
          typename C0DataType,
          typename C1DataType,
          typename A_functor,
          typename B_functor,
          typename C_functor>
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
                         const Tensor<ADataType>& a_m_k,
                         const Tensor<ADataType>& b_k_n,
                         const Tensor<C0DataType>& bias_n,
                         const Tensor<C1DataType>& c1_m_n,
                         const Tensor<GammaDataType>& gamma_n,
                         const Tensor<GammaDataType>& beta_n,
                         A_functor a_element_op,
                         B_functor b_element_op,
                         C_functor c_element_op,
                         int M,
                         int N)
{
    int StrideC = N;
    Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
    Tensor<DDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
    Tensor<DDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
    auto averageOpInst = UnaryDivElementOp{M};

    auto ref_gemm    = ReferenceGemmInstance{};
    auto ref_invoker = ref_gemm.MakeInvoker();

    auto ref_argument =
        ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});

    ref_invoker.Run(ref_argument);

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            AccDataType acc =
                static_cast<AccDataType>(c_m_n(m, n)) + static_cast<AccDataType>(bias_n(n));

            c_element_op(acc, acc);
            acc += static_cast<AccDataType>(c1_m_n(m, n));
            c_m_n(m, n) = static_cast<CDataType>(acc);
        }

    // reduce_mean and reduce_square_mean
    auto reduceSumOpInst = ReduceSumOp{};
    for(int m = 0; m < M; ++m)
    {
        float mean_acc        = reduceSumOpInst.GetReductionZeroVal();
        float square_mean_acc = reduceSumOpInst.GetReductionZeroVal();

        for(int n = 0; n < N; ++n)
        {
            AccDataType c_val        = ck::type_convert<AccDataType>(c_m_n(m, n));
            AccDataType square_c_val = 0;
            UnarySquareElementOp{}(square_c_val, c_val);

            reduceSumOpInst(mean_acc, c_val);
            reduceSumOpInst(square_mean_acc, square_c_val);
        }

        averageOpInst(mean_acc, mean_acc);
        averageOpInst(square_mean_acc, square_mean_acc);
        mean_m(m)       = ck::type_convert<DDataType>(mean_acc);
        meanSquare_m(m) = ck::type_convert<DDataType>(square_mean_acc);
    }

    // LayerNorm
    auto layerNormInst = NormalizeFunctor{};
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
        {
            float out_f32 = 0;
            layerNormInst(out_f32, c_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n));
            out_m_n(m, n) = static_cast<DDataType>(out_f32);
        }
    }
}
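
Collecting the loops above, the host reference amounts to the following per-row pipeline (a worked summary of the code, stated only to make the verification target explicit):

    C_{m,n} = \mathrm{ReLU}\!\Big(\sum_k A_{m,k} B_{k,n} + \mathrm{bias}_n\Big) + C1_{m,n}, \qquad
    D0_m = \frac{1}{M}\sum_n C_{m,n}, \qquad
    D1_m = \frac{1}{M}\sum_n C_{m,n}^{2},

    \mathrm{out}_{m,n} = \mathrm{Normalize}\big(C_{m,n},\, D0_m,\, D1_m,\, \gamma_n,\, \beta_n\big).

Note that both running sums over n are divided by the M handed to UnaryDivElementOp (here, and in dxs_out_element_op on the device side below), which coincides with the row mean because main() uses M = N = 1024.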

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename C0DataType,
          typename C1DataType,
          typename DDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename NormalizeDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, int N, int K)
{
    std::size_t gemm_flop     = std::size_t(2) * M * N * K + std::size_t(2) * M * N;
    std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                                sizeof(CDataType) * M * N + sizeof(C0DataType) * M * N +
                                sizeof(C1DataType) * M * N + sizeof(DDataType) * M +
                                sizeof(DDataType) * M;

    std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(DDataType) * M +
                                     sizeof(DDataType) * M + sizeof(GammaDataType) * N +
                                     sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;

    float tflops               = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
    float gemm_gb_per_sec      = gemm_num_byte / 1.E6 / gemm_reduce_time;
    float normalize_gb_per_sec = normalize_num_byte / 1.E6 / normalize_time;

    std::cout << "gemm + reduce_mean + reduce_square_mean Perf: " << gemm_reduce_time << " ms, "
              << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl;

    std::cout << "5-ary elementwise Perf: " << normalize_time << " ms, " << normalize_gb_per_sec
              << " GB/s, " << std::endl;
}
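
As a unit check on the perf model (an illustration, not part of the file): with the kernel time measured in milliseconds, flop / 1e9 / t_ms has units of TFLOP/s and bytes / 1e6 / t_ms has units of GB/s. For the M = N = K = 1024 problem in main(),

    \mathrm{gemm\_flop} = 2MNK + 2MN = 2\cdot 1024^{3} + 2\cdot 1024^{2} \approx 2.15\times 10^{9},

so a 1 ms fused kernel would be reported as roughly 2.15 TFlops.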

int main()
{
    // GEMM shape
    ck::index_t M = 1024;
    ck::index_t N = 1024;
    ck::index_t K = 1024;

    ck::index_t StrideA  = 1024;
    ck::index_t StrideB  = 1024;
    ck::index_t StrideC  = 1024;
    ck::index_t StrideC1 = 1024;

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
    Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
    Tensor<C0DataType> bias_n(f_host_tensor_descriptor1d(N, 1));
    Tensor<C1DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
    Tensor<DDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
    Tensor<DDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
    Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
    Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
    Tensor<LayerNormOutDataType> layerNorm_m_n(
        f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));

    a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
    b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
    bias_n.GenerateTensorValue(GeneratorTensor_3<C0DataType>{-1, 1});
    c1_m_n.GenerateTensorValue(GeneratorTensor_3<C1DataType>{-5, 5});
    gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
    beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});

    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
    DeviceMem bias_device_buf(sizeof(C0DataType) * bias_n.mDesc.GetElementSpace());
    DeviceMem c1_device_buf(sizeof(C1DataType) * c1_m_n.mDesc.GetElementSpace());
    DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace());
    DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) *
                                          reduceMeanSquare_m.mDesc.GetElementSpace());
    DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
    DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
    DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
                                   layerNorm_m_n.mDesc.GetElementSpace());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    bias_device_buf.ToDevice(bias_n.mData.data());
    c1_device_buf.ToDevice(c1_m_n.mData.data());
    gamma_device_buf.ToDevice(gamma_n.mData.data());
    beta_device_buf.ToDevice(beta_n.mData.data());

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};
    auto dxs_global =
        ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
                       static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));
    auto dxs_in_element_op  = DxsInElementOp{};
    auto dxs_out_element_op = DxsOutElementOp{M, M};

    // Prepare GEMM, reduce_mean, reduce_mean_square
    auto gemmReduce         = DeviceGemmBiasAddReduceInstance{};
    auto gemmReduce_invoker = gemmReduce.MakeInvoker();
    auto gemmReduce_argument =
        gemmReduce.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                static_cast<C0DataType*>(bias_device_buf.GetDeviceBuffer()),
                                static_cast<C1DataType*>(c1_device_buf.GetDeviceBuffer()),
                                dxs_global,
                                M, N, K, StrideA, StrideB, StrideC, StrideC1,
                                a_element_op, b_element_op, c_element_op,
                                dxs_in_element_op, dxs_out_element_op);

    if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
    {
        throw std::runtime_error(
            "wrong! device_gemm with the specified compilation parameters does "
            "not support this GEMM problem");
    }

    reduceMean_device_buf.SetZero();
    reduceMeanSquare_device_buf.SetZero();

    // Prepare LayerNorm
    auto normalize         = DeviceNormalizeInstance{};
    auto normalize_invoker = normalize.MakeInvoker();
    auto normalize_argument = normalize.MakeArgument(
        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
        static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
        static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()),
        static_cast<GammaDataType*>(gamma_device_buf.GetDeviceBuffer()),
        static_cast<BetaDataType*>(beta_device_buf.GetDeviceBuffer()),
        static_cast<LayerNormOutDataType*>(layerNorm_device_buf.GetDeviceBuffer()),
        {M, N},
        {StrideC, 1},
        {1, 0},
        {1, 0},
        {0, 1},
        {0, 1},
        {StrideC, 1},
        NormalizeFunctor{});

    if(!normalize.IsSupportedArgument(normalize_argument))
    {
        throw std::runtime_error("The runtime parameters seems not supported by the "
                                 "Device5AryElementwise instance, exiting!");
    }

    // run kernel
    gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false});
    normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false});

    bool pass = true;
    {
        // verification
        Tensor<LayerNormOutDataType> host_layerNorm_m_n(
            f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));

        host_gemm_layernorm<CDataType, DDataType, ReduceAccDataType>(host_layerNorm_m_n,
                                                                     a_m_k, b_k_n, bias_n, c1_m_n,
                                                                     gamma_n, beta_n,
                                                                     a_element_op, b_element_op,
                                                                     c_element_op, M, N);

        layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
        pass &= ck::utils::check_err(layerNorm_m_n.mData,
                                     host_layerNorm_m_n.mData,
                                     "Error: Incorrect results layerNorm_m_n",
                                     1e-2,
                                     1e-2);
    }

    {
        // evaluate kernel perf
        bool time_kernel = true;

        float gemm_reduce_mean_reduce_square_mean_ave_time =
            gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
        float normalize_ave_time =
            normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel});

        if(time_kernel)
            DumpGemmLayerNormPerf<ADataType,
                                  BDataType,
                                  CDataType,
                                  C0DataType,
                                  C1DataType,
                                  DDataType,
                                  GammaDataType,
                                  BetaDataType,
                                  LayerNormOutDataType>(
                gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K);
    }

    return pass ? 0 : 1;
}

include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp

@@ -7,7 +7,7 @@
 #include "tensor_layout.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
-#include "gridwise_gemm_reduce_xdl_cshuffle_v1.hpp"
+#include "gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp"
 #include "gemm_specialization.hpp"

 namespace ck {

@@ -23,6 +23,8 @@ template <typename ALayout,
           typename ADataType,
           typename BDataType,
           typename CDataType,
+          typename C0DataType,
+          typename C1DataType,
           typename GemmAccDataType,
           typename CShuffleDataType,
           typename ReduceAccDataType,

@@ -68,14 +70,15 @@ template <typename ALayout,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
           index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
+struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmBiasAddReduce<DPtrsGlobal,
                                                                AElementwiseOperation,
                                                                BElementwiseOperation,
                                                                CElementwiseOperation,
                                                                DxsInElementwiseOperation,
                                                                DxsAccElementwiseOperation>
 {
-    using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
+    using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle;

     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

@@ -374,14 +377,18 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
     using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
     using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
     using CGridDesc_M_N       = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
+    using C0GridDesc_M_N      = decltype(MakeCGridDescriptor_M_N(1, 1, 0));
+    using C1GridDesc_M_N      = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
     using DGridDesc_M         = decltype(MakeDGridDescriptor_M(1));

     // GridwiseGemm
-    using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
+    using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
         ADataType, // TODO: distinguish A/B datatype
         GemmAccDataType,
         CShuffleDataType,
         CDataType,
+        C0DataType,
+        C1DataType,
         ReduceAccDataType,
         DPtrsGlobal,
         AElementwiseOperation,

@@ -395,6 +402,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
         AGridDesc_AK0_M_AK1,
         BGridDesc_BK0_N_BK1,
         CGridDesc_M_N,
+        C0GridDesc_M_N,
+        C1GridDesc_M_N,
         DGridDesc_M,
         NumGemmKPrefetchStage,
         BlockSize,

@@ -438,6 +447,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
         Argument(const ADataType* p_a_grid,
                  const BDataType* p_b_grid,
                  CDataType* p_c_grid,
+                 const C0DataType* p_c0_grid,
+                 const C1DataType* p_c1_grid,
                  DPtrsGlobal p_ds_grid,
                  index_t MRaw,
                  index_t NRaw,

@@ -445,6 +456,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
                  index_t StrideA,
                  index_t StrideB,
                  index_t StrideC,
+                 index_t StrideC1,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op,

@@ -453,12 +465,18 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_c_grid_{p_c_grid},
+              p_c0_grid_{p_c0_grid},
+              p_c1_grid_{p_c1_grid},
               p_ds_grid_{p_ds_grid},
               a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
               b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
               c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
+              c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)},
+              c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)},
               d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
               c_grid_desc_mblock_mperblock_nblock_nperblock_{},
+              c0_grid_desc_mblock_mperblock_nblock_nperblock_{},
+              c1_grid_desc_mblock_mperblock_nblock_nperblock_{},
               d_grid_desc_mblock_mperblock_{},
               block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
               a_element_op_{a_element_op},

@@ -476,6 +494,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
                     GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);

+                c0_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c0_grid_desc_m_n_);
+
+                c1_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c1_grid_desc_m_n_);
+
                 d_grid_desc_mblock_mperblock_ =
                     GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
             }

@@ -485,13 +511,21 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
+        const C0DataType* p_c0_grid_;
+        const C1DataType* p_c1_grid_;
         DPtrsGlobal p_ds_grid_;
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
         CGridDesc_M_N c_grid_desc_m_n_;
+        C0GridDesc_M_N c0_grid_desc_m_n_;
+        C1GridDesc_M_N c1_grid_desc_m_n_;
         DGridDesc_M d_grid_desc_m_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;
+        typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c0_grid_desc_mblock_mperblock_nblock_nperblock_;
+        typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c1_grid_desc_mblock_mperblock_nblock_nperblock_;
         typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
         AElementwiseOperation a_element_op_;

@@ -508,26 +542,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-#if 0
-            {
-                std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
-                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
-                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
-                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
-
-                std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
-                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
-                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
-                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
-
-                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
-                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
-
-                std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
-                          << std::endl;
-            }
-#endif
             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                             arg.b_grid_desc_bk0_n_bk1_,
                                             arg.c_grid_desc_m_n_,

@@ -545,10 +559,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
             float elapsed_time = 0.0f;
             if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
-                const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
+                const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1<
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
                     CDataType,
+                    C0DataType,
+                    C1DataType,
                     DPtrsGlobal,
                     AElementwiseOperation,
                     BElementwiseOperation,

@@ -558,6 +574,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     true>;

@@ -571,6 +589,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
+                    arg.p_c0_grid_,
+                    arg.p_c1_grid_,
                     arg.p_ds_grid_,
                     arg.a_element_op_,
                     arg.b_element_op_,

@@ -580,15 +600,19 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
                     arg.a_grid_desc_ak0_m_ak1_,
                     arg.b_grid_desc_bk0_n_bk1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                    arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
+                    arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
                     arg.d_grid_desc_mblock_mperblock_,
                     arg.block_2_ctile_map_);
             }
             else
             {
-                const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
+                const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1<
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
                     CDataType,
+                    C0DataType,
+                    C1DataType,
                     DPtrsGlobal,
                     AElementwiseOperation,
                     BElementwiseOperation,

@@ -598,6 +622,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     false>;

@@ -611,6 +637,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
+                    arg.p_c0_grid_,
+                    arg.p_c1_grid_,
                     arg.p_ds_grid_,
                     arg.a_element_op_,
                     arg.b_element_op_,

@@ -620,6 +648,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
                     arg.a_grid_desc_ak0_m_ak1_,
                     arg.b_grid_desc_bk0_n_bk1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                    arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
+                    arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
                     arg.d_grid_desc_mblock_mperblock_,
                     arg.block_2_ctile_map_);
             }

@@ -658,6 +688,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
     static auto MakeArgument(const ADataType* p_a,
                              const BDataType* p_b,
                              CDataType* p_c,
+                             const C0DataType* p_c0,
+                             const C1DataType* p_c1,
                              DPtrsGlobal p_dxs,
                              index_t MRaw,
                              index_t NRaw,

@@ -665,6 +697,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
                              index_t StrideA,
                              index_t StrideB,
                              index_t StrideC,
+                             index_t StrideC1,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
                              CElementwiseOperation c_element_op,

@@ -674,6 +707,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
         return Argument{p_a,
                         p_b,
                         p_c,
+                        p_c0,
+                        p_c1,
                         p_dxs,
                         MRaw,
                         NRaw,

@@ -681,6 +716,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
                         StrideA,
                         StrideB,
                         StrideC,
+                        StrideC1,
                         a_element_op,
                         b_element_op,
                         c_element_op,

@@ -694,6 +730,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
     std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                       const void* p_b,
                                                       void* p_c,
+                                                      const void* p_c0,
+                                                      const void* p_c1,
                                                       DPtrsGlobal p_dxs,
                                                       index_t MRaw,
                                                       index_t NRaw,

@@ -701,6 +739,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
                                                       index_t StrideA,
                                                       index_t StrideB,
                                                       index_t StrideC,
+                                                      index_t StrideC1,
                                                       AElementwiseOperation a_element_op,
                                                       BElementwiseOperation b_element_op,
                                                       CElementwiseOperation c_element_op,

@@ -711,6 +750,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
                                           static_cast<CDataType*>(p_c),
+                                          static_cast<const C0DataType*>(p_c0),
+                                          static_cast<const C1DataType*>(p_c1),
                                           p_dxs,
                                           MRaw,
                                           NRaw,

@@ -718,6 +759,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
                                           StrideA,
                                           StrideB,
                                           StrideC,
+                                          StrideC1,
                                           a_element_op,
                                           b_element_op,
                                           c_element_op,

include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp

@@ -48,6 +48,52 @@ using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
                                                              DxsInElementwiseOperation,
                                                              DxsAccElementwiseOperation>>;

+template <typename DPtrsGlobal,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          typename DxsInElementwiseOperation,
+          typename DxsAccElementwiseOperation>
+struct DeviceGemmBiasAddReduce : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        void* p_c,
+                        const void* p_c0,
+                        const void* p_c1,
+                        DPtrsGlobal p_dxs,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t StrideA,
+                        ck::index_t StrideB,
+                        ck::index_t StrideC,
+                        ck::index_t StrideC1,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CElementwiseOperation c_element_op,
+                        DxsInElementwiseOperation dxs_in_element_op,
+                        DxsAccElementwiseOperation dxs_out_element_op,
+                        ck::index_t BatchCount = 1) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename DPtrsGlobal,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          typename DxsInElementwiseOperation,
+          typename DxsAccElementwiseOperation>
+using DeviceGemmBiasAddReducePtr =
+    std::unique_ptr<DeviceGemmBiasAddReduce<DPtrsGlobal,
+                                            AElementwiseOperation,
+                                            BElementwiseOperation,
+                                            CElementwiseOperation,
+                                            DxsInElementwiseOperation,
+                                            DxsAccElementwiseOperation>>;
+
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
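
A minimal, hypothetical sketch (not part of this commit) of how this type-erased interface lines up with the concrete operation above: the element-wise operation types that parameterize the base match the ones used to instantiate DeviceGemmBiasAddReduce_Xdl_CShuffle in the example file, so a concrete instance can be held behind DeviceGemmBiasAddReducePtr and driven through void* buffers, the pattern instance libraries and profilers typically use. The aliases DPtrsGlobal, AElementOp, etc. are assumed to come from gemm_bias_relu_add_layernorm_xdl_fp16.cpp.

    // Hypothetical usage sketch; reuses the type aliases from the example file above.
    #include <memory>

    using DeviceOpBasePtr = ck::tensor_operation::device::DeviceGemmBiasAddReducePtr<DPtrsGlobal,
                                                                                     AElementOp,
                                                                                     BElementOp,
                                                                                     CElementOp,
                                                                                     DxsInElementOp,
                                                                                     DxsOutElementOp>;

    // The concrete xdl instance derives from the abstract DeviceGemmBiasAddReduce above,
    // so it converts to the base-class pointer and can later be driven through
    // MakeArgumentPointer() / MakeInvokerPointer() with void* device buffers.
    inline DeviceOpBasePtr MakeTypeErasedGemmBiasAddReduce()
    {
        return std::make_unique<DeviceGemmBiasAddReduceInstance>();
    }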

include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

@@ -143,6 +143,21 @@ struct AddHardswishAdd
     }
 };

+struct Relu
+{
+    __host__ __device__ void operator()(float& y, const float& x) const { y = x > 0 ? x : 0; }
+    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x > 0 ? x : 0; }
+    __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { y = x > 0 ? x : 0; }
+    __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x > 0 ? x : 0; }
+    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x > 0 ? x : 0; }
+    __host__ __device__ void operator()(double& y, const double& x) const { y = x > 0 ? x : 0; }
+};
+
 struct Normalize
 {
     Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {}
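
A small host-side sketch of the new Relu functor (illustration only, not in the commit); writing the result through the first reference argument is what lets the example's epilogue call c_element_op(acc, acc) in place:

    #include <cassert>
    #include "element_wise_operation.hpp"

    int main()
    {
        ck::tensor_operation::element_wise::Relu relu{};
        float y = 0.f;
        relu(y, -3.5f); // negative inputs clamp to 0
        assert(y == 0.f);
        relu(y, 2.25f); // positive inputs pass through unchanged
        assert(y == 2.25f);
        return 0;
    }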

include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp

@@ -16,6 +16,8 @@ namespace ck {
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
+          typename FloatC0,
+          typename FloatC1,
           typename DPtrsGlobal,
           typename AElementwiseOperation,
           typename BElementwiseOperation,

@@ -25,6 +27,8 @@ template <typename GridwiseGemm,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename DGridDescriptor_MBlock_MPerBlock,
           typename Block2CTileMap,
           bool HasMainKBlockLoop>

@@ -32,10 +36,12 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_reduce_xdl_cshuffle_v1(
+        kernel_gemm_bias_add_reduce_xdl_cshuffle_v1(
             const FloatAB* __restrict__ p_a_grid,
             const FloatAB* __restrict__ p_b_grid,
             FloatC* __restrict__ p_c_grid,
+            const FloatC0* __restrict__ p_c0_grid,
+            const FloatC1* __restrict__ p_c1_grid,
             DPtrsGlobal p_ds_grid,
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,

@@ -46,6 +52,10 @@ __global__ void
             const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
             const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
+            const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                c0_grid_desc_mblock_mperblock_nblock_nperblock,
+            const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                c1_grid_desc_mblock_mperblock_nblock_nperblock,
             const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
             const Block2CTileMap block_2_ctile_map)
 {

@@ -55,6 +65,8 @@ __global__ void
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
+                                                  p_c0_grid,
+                                                  p_c1_grid,
                                                   p_ds_grid,
                                                   p_shared,
                                                   a_element_op,

@@ -65,12 +77,16 @@ __global__ void
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  c0_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  c1_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   d_grid_desc_mblock_mperblock,
                                                   block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
     ignore = p_c_grid;
+    ignore = p_c0_grid;
+    ignore = p_c1_grid;
     ignore = p_ds_grid;
     ignore = a_element_op;
     ignore = b_element_op;

@@ -80,6 +96,8 @@ __global__ void
     ignore = a_grid_desc_ak0_m_ak1;
     ignore = b_grid_desc_bk0_n_bk1;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = d_grid_desc_mblock_mperblock;
     ignore = block_2_ctile_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))

@@ -89,6 +107,8 @@ template <typename FloatAB,
           typename FloatGemmAcc,
           typename FloatCShuffle,
           typename FloatC,
+          typename FloatC0,
+          typename FloatC1,
           typename FloatReduceAcc,
           typename DPtrsGlobal,
           typename AElementwiseOperation,

@@ -102,6 +122,8 @@ template <typename FloatAB,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDesc_M_N,
+          typename C0GridDesc_M_N,
+          typename C1GridDesc_M_N,
           typename DGridDesc_M,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,

@@ -138,7 +160,7 @@ template <typename FloatAB,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
           index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
           LoopScheduler LoopSched>
-struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
+struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

@@ -268,8 +290,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }

+    template <typename CGridDesc_M_N_>
     __host__ __device__ static constexpr auto
-    MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
+    MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N_& c_grid_desc_m_n)
     {
         const auto M = c_grid_desc_m_n.GetLength(I0);
         const auto N = c_grid_desc_m_n.GetLength(I1);

@@ -313,6 +336,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
+    using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+        MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C0GridDesc_M_N{}))>;
+    using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+        MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>;
     using DGridDescriptor_MBlock_MPerBlock =
         remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>;

@@ -323,6 +352,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
+                               const FloatC0* __restrict__ p_c0_grid,
+                               const FloatC1* __restrict__ p_c1_grid,
                                DPtrsGlobal p_ds_grid,
                                void* __restrict__ p_shared,
                                const AElementwiseOperation& a_element_op,

@@ -334,6 +365,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                                const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
+                               const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
+                                   c0_grid_desc_mblock_mperblock_nblock_nperblock,
+                               const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
+                                   c1_grid_desc_mblock_mperblock_nblock_nperblock,
                                const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
                                const Block2CTileMap& block_2_ctile_map)
     {

@@ -343,6 +378,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+        auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_c0_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+        auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_c1_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

         // divide block work by [M, N]
         const auto block_work_idx =

@@ -610,32 +649,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                                          n_thread_data_on_block_idx[I2]),
                 ck::tensor_operation::element_wise::PassThrough{}};

-        // shuffle: blockwise copy C from LDS to global
-        auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
-            ThisThreadBlock,            // ThreadGroup
-            CElementwiseOperation,      // ElementwiseOperation,
-            CGlobalMemoryDataOperation, // DstInMemOp,
-            Sequence<1,
-                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                     1,
-                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-            FloatCShuffle,        // typename SrcData,
-            FloatC,               // typename DstData,
-            decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
-            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
-            Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
-            3,                                              // index_t VectorDim,
-            CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
-            true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
-            false> // bool ThreadTransferDstResetCoordinateAfterRun>
-            {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
-             make_multi_index(0, 0, 0, 0),
-             c_grid_desc_mblock_mperblock_nblock_nperblock,
-             make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
-             c_element_op};
-
         // space filling curve for threadwise C in VGPR
         constexpr auto sfc_c_vgpr =
             SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,

@@ -759,14 +772,80 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 },
                 Number<p_ds_grid.Size()>{});

+        // c0 and c1
+        constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
+            make_naive_tensor_descriptor_packed(
+                make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
+
+        constexpr auto c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
+            c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock;
+
+        auto c01_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
+            c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+
+        auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
+            FloatC0,
+            FloatReduceAcc,
+            decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock),
+            decltype(c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
+            Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
+            Sequence<0, 1, 2, 3>,
+            3,
+            CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
+            1,
+            true>(
+            c0_grid_desc_mblock_mperblock_nblock_nperblock,
+            make_multi_index(I0,
+                             m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
+                             I0,
+                             n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
+
+        auto c1_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
+            FloatC1,
+            FloatReduceAcc,
+            decltype(c1_grid_desc_mblock_mperblock_nblock_nperblock),
+            decltype(c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
+            Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
+            Sequence<0, 1, 2, 3>,
+            3,
+            CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
+            1,
+            true>(
+            c1_grid_desc_mblock_mperblock_nblock_nperblock,
+            make_multi_index(I0,
+                             m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
+                             I0,
+                             n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
+
+        constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
+            make_naive_tensor_descriptor_packed(
+                make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
+
+        auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
+            FloatReduceAcc,
+            FloatC,
+            decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
+            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>, // SliceLengths
+            Sequence<0, 1, 2, 3>,                                     // DimAccessOrder
+            3,                                                        // DstVectorDim
+            CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
+            InMemoryDataOperationEnum::Set,
+            1,
+            true>{
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+            make_multi_index(I0,
+                             m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
+                             I0,
+                             n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]),
+            tensor_operation::element_wise::PassThrough{}};
+
         constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
         static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

         static_for<0, num_access, 1>{}([&](auto access_id) {
             // make sure it's safe to write to LDS
             block_sync_lds();

             // each thread write its data from VGPR to LDS
             c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                           sfc_c_vgpr.GetIndexTupleOfNumber(access_id),

@@ -774,17 +853,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                                           c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                           c_shuffle_block_buf);

-            // make sure it's safe to read from LDS
+            // make sure it's safe to write to LDS
             block_sync_lds();

-            // each block copy its data from LDS to global
-            c_shuffle_block_copy_lds_to_global.Run(
-                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
-                c_shuffle_block_buf,
-                c_grid_desc_mblock_mperblock_nblock_nperblock,
-                c_grid_buf);
-
             // TODO - extract following into reduction_blockwise
             {
                 c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
                                                      c_shuffle_block_buf,

@@ -792,6 +862,37 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                                                      make_tuple(I0, I0),
                                                      c_reduce_thread_buf);

+                c0_thread_copy_global_to_vgpr.Run(
+                    c0_grid_desc_mblock_mperblock_nblock_nperblock,
+                    c0_grid_buf,
+                    c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
+                    make_tuple(I0, I0, I0, I0),
+                    c01_thread_buf);
+
+                static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
+                    [&](auto i) {
+                        FloatReduceAcc out;
+                        c_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i));
+                        c_reduce_thread_buf(i) = out;
+                    });
+
+                c1_thread_copy_global_to_vgpr.Run(
+                    c1_grid_desc_mblock_mperblock_nblock_nperblock,
+                    c1_grid_buf,
+                    c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
+                    make_tuple(I0, I0, I0, I0),
+                    c01_thread_buf);
+
+                static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
+                    [&](auto i) { c_reduce_thread_buf(i) += c01_thread_buf(i); });
+
+                c_reduce_thread_copy_vgpr_to_global.Run(
+                    c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
+                    make_tuple(I0, I0, I0, I0),
+                    c_reduce_thread_buf,
+                    c_grid_desc_mblock_mperblock_nblock_nperblock,
+                    c_grid_buf);
+
                 static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
                     auto& p_d_grid = p_ds_grid[In];

@@ -858,13 +959,19 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                 // move on C
-                c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
+                c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                     c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
-            }
-        });

-        // Reduction
+                // move on C0
+                c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
+                    c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
+
+                // move on C1
+                c1_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
+                    c1_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
+            }
+        });
     }
 };