Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5215f11d
Commit
5215f11d
authored
Dec 05, 2022
by
rocking
Browse files
Wrtie out the e for debug.
This could be remove and use h for instead
parent
3b97076d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
5 deletions
+45
-5
example/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
...ple/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
+38
-1
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+7
-4
No files found.
example/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
View file @
5215f11d
...
...
@@ -48,7 +48,6 @@ using BLayout = Col;
using
D0Layout
=
Row
;
using
D1Layout
=
Row
;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
,
D1Layout
>
;
using
ELayout
=
Row
;
using
HLayout
=
Row
;
using
AElementOp
=
PassThrough
;
...
...
@@ -67,6 +66,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern
<
ALayout
,
BLayout
,
DsLayout
,
HLayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
GammaDataType
,
BetaDataType
,
HDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
HElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
64
,
4
>
,
4
,
S
<
8
,
32
>
,
S
<
1
,
8
>
,
1
,
8
,
8
,
8
,
8
,
1
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
AccDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
PassThrough
>
;
auto
f_host_tensor_descriptor1d
=
[](
std
::
size_t
len
,
std
::
size_t
stride
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
len
}),
std
::
vector
<
std
::
size_t
>
({
stride
}));
...
...
@@ -88,6 +95,8 @@ auto f_host_tensor_descriptor2d =
int
main
()
{
bool
do_verification
=
true
;
// GEMM shape
ck
::
index_t
M
=
1024
;
ck
::
index_t
N
=
1024
;
...
...
@@ -107,6 +116,7 @@ int main()
Tensor
<
D1DataType
>
d1_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
StrideD1
,
D1Layout
{}));
Tensor
<
GammaDataType
>
gamma_n
(
f_host_tensor_descriptor1d
(
N
,
1
));
Tensor
<
BetaDataType
>
beta_n
(
f_host_tensor_descriptor1d
(
N
,
1
));
Tensor
<
HDataType
>
e_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
StrideH
,
HLayout
{}));
Tensor
<
HDataType
>
h_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
StrideH
,
HLayout
{}));
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
-
1
,
1
});
...
...
@@ -122,6 +132,7 @@ int main()
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
gamma_device_buf
(
sizeof
(
GammaDataType
)
*
gamma_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
beta_device_buf
(
sizeof
(
BetaDataType
)
*
beta_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
HDataType
)
*
e_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
h_device_buf
(
sizeof
(
HDataType
)
*
h_m_n
.
mDesc
.
GetElementSpaceSize
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
...
...
@@ -144,6 +155,7 @@ int main()
{
d0_device_buf
.
GetDeviceBuffer
(),
d1_device_buf
.
GetDeviceBuffer
()},
gamma_device_buf
.
GetDeviceBuffer
(),
beta_device_buf
.
GetDeviceBuffer
(),
e_device_buf
.
GetDeviceBuffer
(),
h_device_buf
.
GetDeviceBuffer
(),
M
,
N
,
...
...
@@ -164,4 +176,29 @@ int main()
}
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
if
(
do_verification
)
{
Tensor
<
AccDataType
>
c_m_n_host
(
HostTensorDescriptor
{
M
,
N
});
Tensor
<
HDataType
>
e_m_n_host
(
HostTensorDescriptor
{
M
,
N
});
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n_host
,
a_element_op
,
b_element_op
,
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
cde_element_op
(
e_m_n_host
(
m
,
n
),
c_m_n_host
(
m
,
n
),
d0_n
(
n
),
d1_m_n
(
m
,
n
));
}
}
e_device_buf
.
FromDevice
(
e_m_n
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
e_m_n
,
e_m_n_host
)
?
0
:
1
;
}
}
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
5215f11d
...
...
@@ -463,6 +463,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
const
void
*
p_gamma_grid
,
const
void
*
p_beta_grid
,
void
*
p_e_grid
,
void
*
p_h_grid
,
index_t
MRaw
,
index_t
NRaw
,
...
...
@@ -479,7 +480,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
p_e_grid_
{
nullptr
},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)
},
p_welford_mean_grid_
{
nullptr
},
p_welford_var_grid_
{
nullptr
},
p_welford_count_grid_
{
nullptr
},
...
...
@@ -509,9 +510,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
mean_var_count_grid_desc_m_nblock_
=
DeviceOp
::
MakeMeanVarCountGridDescriptor_M_NBlock
(
MRaw
,
gemm_nblock_
);
// TODO - hipFree
hip_check_error
(
hipMalloc
(
&
p_e_grid_
,
sizeof
(
EDataType
)
*
MRaw
*
NRaw
));
// TODO - GetWorkSpaceSize(), let user hipMalloc the memory
int
gemm_welford_size
=
MRaw
*
gemm_nblock_
;
hip_check_error
(
hipMalloc
(
&
p_welford_mean_grid_
,
sizeof
(
MeanDataType
)
*
gemm_welford_size
));
...
...
@@ -770,6 +769,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
const
void
*
p_gamma
,
const
void
*
p_beta
,
void
*
p_e
,
void
*
p_h
,
index_t
MRaw
,
index_t
NRaw
,
...
...
@@ -789,6 +789,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_ds
,
p_gamma
,
p_beta
,
p_e
,
p_h
,
MRaw
,
NRaw
,
...
...
@@ -812,6 +813,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
const
void
*
p_gamma
,
const
void
*
p_beta
,
void
*
p_e
,
void
*
p_h
,
index_t
MRaw
,
index_t
NRaw
,
...
...
@@ -831,6 +833,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_ds
,
p_gamma
,
p_beta
,
p_e
,
p_h
,
MRaw
,
NRaw
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment