gaoqiong / composable_kernel · Commits · dde01029

Commit dde01029, authored May 31, 2022 by Anthony Chang

clang-format

parent 597155e8
Showing 4 changed files with 41 additions and 23 deletions (+41 -23)
example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp  +14 -11
include/ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp  +8 -4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp  +18 -7
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp  +1 -1
example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp
...
@@ -26,11 +26,11 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using ADataType        = F16;
using BDataType        = F16;
using CDataType        = F16;
using C0DataType       = F16;
using AccDataType      = F32;
using CShuffleDataType = F16;

using ALayout = ck::tensor_layout::gemm::RowMajor;
...
@@ -39,7 +39,7 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
struct Relu
{
    template <typename OutT, typename InT>
    __host__ __device__ void operator()(OutT& y, const InT& x) const
    {
        y = x > 0 ? x : 0;
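The Relu functor above is a plain element-wise op. As a standalone sketch (not part of this commit; the __host__ __device__ qualifiers are dropped so it builds as ordinary C++), it can be exercised like this:

#include <cassert>

// Same shape as the Relu functor in the diff: writes max(x, 0) into y.
struct Relu
{
    template <typename OutT, typename InT>
    void operator()(OutT& y, const InT& x) const
    {
        y = x > 0 ? x : 0;
    }
};

int main()
{
    float y = 0.f;
    Relu{}(y, -3.5f); // negative input clamps to zero
    assert(y == 0.f);
    Relu{}(y, 2.0f);  // positive input passes through
    assert(y == 2.0f);
    return 0;
}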
...
@@ -187,10 +187,10 @@ int main(int argc, char* argv[])
    c0_gamma_buf.ToDevice(c0_n_gamma.mData.data());
    c0_beta_buf.ToDevice(c0_n_beta.mData.data());

    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto acc_element_op = AccElementOp{};
    auto c_element_op   = CElementOp{};

    // do GEMM
    auto gemm = DeviceGemmInstance{};
...
@@ -262,8 +262,11 @@ int main(int argc, char* argv[])
        }
        else if constexpr(std::is_same<CShuffleDataType, F16>::value)
        {
            pass &= ck::utils::check_err(c_m_n_device_result.mData,
                                         c_m_n_host_result.mData,
                                         "Error: Incorrect results c",
                                         1e-2,
                                         1e-2);
        }
    }

    return pass ? 0 : 1;
...
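The 1e-2 relative and absolute tolerances are loose because the F16 CShuffle path accumulates rounding error across K. A rough sketch of the kind of element-wise comparison a check like ck::utils::check_err performs (an illustration only, not CK's actual implementation):

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Illustrative tolerance check: pass if every element satisfies
// |out - ref| <= atol + rtol * |ref|.
bool approx_equal(const std::vector<float>& out,
                  const std::vector<float>& ref,
                  double rtol = 1e-2,
                  double atol = 1e-2)
{
    if(out.size() != ref.size())
        return false;
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        if(std::fabs(out[i] - ref[i]) > atol + rtol * std::fabs(ref[i]))
            return false;
    }
    return true;
}

int main()
{
    std::vector<float> device_result{1.001f, 2.002f};
    std::vector<float> host_result{1.0f, 2.0f};
    std::printf("pass: %d\n", approx_equal(device_result, host_result));
    return 0;
}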
include/ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp
...
@@ -462,8 +462,10 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
              acc_element_op_{acc_element_op},
              c_element_op_{c_element_op}
        {
            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                           b_grid_desc_bk0_n_bk1_,
                                           c_grid_desc_m_n_,
                                           block_2_ctile_map_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...
@@ -519,8 +521,10 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
        }
#endif
        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                        arg.b_grid_desc_bk0_n_bk1_,
                                        arg.c_grid_desc_m_n_,
                                        arg.block_2_ctile_map_))
        {
            throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
        }
...
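Because an invalid gridwise configuration is reported by throwing, a caller typically guards the launch with try/catch. A minimal sketch; only std::runtime_error and the message text come from this file, the RunGemm stand-in is hypothetical:

#include <iostream>
#include <stdexcept>

// Hypothetical stand-in for invoking the device op; the real code throws from
// DeviceGemmLayerNorm_Xdl_CShuffle when CheckValidity fails.
void RunGemm(bool valid)
{
    if(!valid)
        throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}

int main()
{
    try
    {
        RunGemm(/*valid=*/false);
    }
    catch(const std::runtime_error& e)
    {
        std::cerr << e.what() << '\n'; // surfaces the message raised by the device op
        return 1;
    }
    return 0;
}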
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
...
@@ -37,7 +37,7 @@ __global__ void
    kernel_gemm_layernorm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
                                          const FloatAB* __restrict__ p_b_grid,
                                          FloatC* __restrict__ p_c_grid,               // MxN
                                          const FloatC0* __restrict__ p_c0_bias_grid,  // 1xN
                                          const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN
                                          const FloatC0* __restrict__ p_c0_beta_grid,  // 1xN
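The comments spell out the broadcast shapes: C is MxN while the bias, gamma and beta vectors are 1xN. A rough CPU sketch of the per-row math the fused kernel produces (the epsilon value and the biased-variance choice are assumptions, not taken from this commit):

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Illustrative per-row layernorm of (acc + bias), scaled by gamma and shifted by beta.
void gemm_layernorm_row(const std::vector<float>& acc_row, // one row of A*B, length N
                        const std::vector<float>& bias,    // 1xN
                        const std::vector<float>& gamma,   // 1xN
                        const std::vector<float>& beta,    // 1xN
                        std::vector<float>& c_row,         // output row, length N
                        float eps = 1e-5f)                 // assumed epsilon
{
    const std::size_t N = acc_row.size();
    float mean = 0.f, var = 0.f;
    for(std::size_t n = 0; n < N; ++n)
        mean += (acc_row[n] + bias[n]) / N;
    for(std::size_t n = 0; n < N; ++n)
    {
        const float d = acc_row[n] + bias[n] - mean;
        var += d * d / N;
    }
    for(std::size_t n = 0; n < N; ++n)
        c_row[n] = gamma[n] * (acc_row[n] + bias[n] - mean) / std::sqrt(var + eps) + beta[n];
}

int main()
{
    std::vector<float> acc{1.f, 2.f, 3.f, 4.f}, bias(4, 0.5f), gamma(4, 1.f), beta(4, 0.f), c(4);
    gemm_layernorm_row(acc, bias, gamma, beta, c);
    for(float v : c)
        std::printf("%f\n", v);
    return 0;
}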
...
@@ -218,15 +218,20 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        // Align 16 bytes (maximum LDS read/write width)
        constexpr auto c_block_size_aligned =
            math::integer_least_multiple(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() *
                    sizeof(FloatCShuffle),
                16) /
            sizeof(FloatCShuffle);

        // LDS allocation for reduction workspace
        constexpr index_t c_lds_workspace_size = BlockSize;

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                             sizeof(FloatAB),
                         c_block_size_aligned * sizeof(FloatCShuffle) +
                             c_lds_workspace_size * sizeof(FloatReduceAcc));
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
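As a worked example of the alignment above, with assumed numbers (not from this commit): say the CShuffle LDS tile holds 1030 fp16 elements and BlockSize is 256. The rounding to a 16-byte multiple and the epilogue-side LDS total then come out as below; the surrounding function returns the max of this and the A/B tile bytes.

#include <cstdio>

// Round x up to the nearest multiple of m (same role as math::integer_least_multiple).
constexpr unsigned integer_least_multiple(unsigned x, unsigned m) { return (x + m - 1) / m * m; }

int main()
{
    // Assumed sizes, for illustration only.
    constexpr unsigned cshuffle_elems   = 1030; // elements in the CShuffle LDS tile
    constexpr unsigned sizeof_cshuffle  = 2;    // fp16
    constexpr unsigned sizeof_reduceacc = 4;    // fp32 reduction accumulator
    constexpr unsigned block_size       = 256;  // one workspace slot per thread

    // Align the CShuffle region to 16 bytes, then convert back to element count.
    constexpr unsigned c_block_size_aligned =
        integer_least_multiple(cshuffle_elems * sizeof_cshuffle, 16) / sizeof_cshuffle; // 1032

    // Epilogue-path LDS: CShuffle tile plus per-thread reduction workspace.
    constexpr unsigned epilogue_bytes =
        c_block_size_aligned * sizeof_cshuffle + block_size * sizeof_reduceacc; // 2064 + 1024

    std::printf("aligned elements: %u, epilogue LDS bytes: %u\n",
                c_block_size_aligned, epilogue_bytes);
    return 0;
}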
...
@@ -738,11 +743,17 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());

        // Align 16 bytes (maximum LDS read/write width)
        constexpr auto c_block_size_aligned =
            math::integer_least_multiple(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() *
                    sizeof(FloatCShuffle),
                16) /
            sizeof(FloatCShuffle);

        auto d_reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            reinterpret_cast<FloatReduceAcc*>(static_cast<FloatCShuffle*>(p_shared) +
                                              c_block_size_aligned),
            BlockSize);

        // Sum thread workspace
        auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
...
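The d_reduce_work_buf above carves the reduction workspace out of the same p_shared allocation, starting just past the 16-byte-aligned CShuffle region. A small host-side sketch of that pointer arithmetic (buffer sizes are illustrative):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    using FloatCShuffle  = std::uint16_t; // stands in for fp16 storage
    using FloatReduceAcc = float;

    constexpr std::size_t c_block_size_aligned = 1032; // CShuffle elements, 16-byte aligned
    constexpr std::size_t block_size           = 256;  // one accumulator slot per thread

    // One flat allocation, playing the role of the kernel's p_shared LDS pointer.
    std::vector<unsigned char> shared(c_block_size_aligned * sizeof(FloatCShuffle) +
                                      block_size * sizeof(FloatReduceAcc));
    void* p_shared = shared.data();

    // First region: the CShuffle tile.
    FloatCShuffle* c_shuffle_tile = static_cast<FloatCShuffle*>(p_shared);

    // Second region: the reduction workspace, right after the aligned tile.
    FloatReduceAcc* d_reduce_work =
        reinterpret_cast<FloatReduceAcc*>(c_shuffle_tile + c_block_size_aligned);
    (void)d_reduce_work;

    std::printf("tile bytes: %zu, workspace bytes: %zu\n",
                c_block_size_aligned * sizeof(FloatCShuffle),
                block_size * sizeof(FloatReduceAcc));
    return 0;
}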
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp
...
@@ -149,7 +149,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
        RunLayernorm(arg.c_m_n_, acc_m_n, arg.c0_n_bias_, arg.c0_n_gamma_, arg.c0_n_beta_);

        arg.c_m_n_.ForEach([&](auto& self, auto idx) {
            arg.c_element_op_(self(idx[0], idx[1]), self(idx[0], idx[1]));
        });
...
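The ForEach loop above applies the user's C element-wise op in place after the reference layernorm. A minimal stand-in for that pattern (the Tensor2D helper is illustrative, not CK's Tensor class):

#include <cstddef>
#include <cstdio>
#include <vector>

// Minimal 2-D tensor with a ForEach that mirrors the reference code's usage.
struct Tensor2D
{
    std::size_t M, N;
    std::vector<float> data;

    float& operator()(std::size_t m, std::size_t n) { return data[m * N + n]; }

    template <typename F>
    void ForEach(F f)
    {
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
                f(*this, m, n);
    }
};

int main()
{
    auto relu = [](float& y, const float& x) { y = x > 0 ? x : 0; };

    Tensor2D c{2, 2, {-1.f, 2.f, -3.f, 4.f}};
    // In-place element-wise op, like arg.c_element_op_(self(idx[0], idx[1]), self(idx[0], idx[1])).
    c.ForEach([&](Tensor2D& self, std::size_t m, std::size_t n) { relu(self(m, n), self(m, n)); });

    for(float v : c.data)
        std::printf("%f\n", v);
    return 0;
}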