Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
eead0864
Commit
eead0864
authored
Dec 12, 2022
by
rocking
Browse files
Share E and H memory in device op
parent
a1cc1504
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
1 addition
and
12 deletions
+1
-12
example/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
...ple/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
+0
-6
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+1
-6
No files found.
example/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
View file @
eead0864
...
...
@@ -163,7 +163,6 @@ int main()
Tensor
<
D1DataType
>
d1_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
StrideD1
,
D1Layout
{}));
Tensor
<
GammaDataType
>
gamma_n
(
f_host_tensor_descriptor1d
(
N
,
1
));
Tensor
<
BetaDataType
>
beta_n
(
f_host_tensor_descriptor1d
(
N
,
1
));
Tensor
<
HDataType
>
e_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
StrideH
,
HLayout
{}));
Tensor
<
HDataType
>
h_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
StrideH
,
HLayout
{}));
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
-
1
,
1
});
...
...
@@ -179,7 +178,6 @@ int main()
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
gamma_device_buf
(
sizeof
(
GammaDataType
)
*
gamma_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
beta_device_buf
(
sizeof
(
BetaDataType
)
*
beta_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
HDataType
)
*
e_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
h_device_buf
(
sizeof
(
HDataType
)
*
h_m_n
.
mDesc
.
GetElementSpaceSize
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
...
...
@@ -202,7 +200,6 @@ int main()
{
d0_device_buf
.
GetDeviceBuffer
(),
d1_device_buf
.
GetDeviceBuffer
()},
gamma_device_buf
.
GetDeviceBuffer
(),
beta_device_buf
.
GetDeviceBuffer
(),
e_device_buf
.
GetDeviceBuffer
(),
h_device_buf
.
GetDeviceBuffer
(),
M
,
N
,
...
...
@@ -250,10 +247,7 @@ int main()
N
,
epsilon
);
e_device_buf
.
FromDevice
(
e_m_n
.
mData
.
data
());
h_device_buf
.
FromDevice
(
h_m_n
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
e_m_n
,
e_m_n_host
,
"Error: Incorrect results e_m_n"
);
pass
&=
ck
::
utils
::
check_err
(
h_m_n
,
h_m_n_host
,
"Error: Incorrect results h_m_n"
,
1e-2
,
1e-2
);
}
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
eead0864
...
...
@@ -499,7 +499,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
const
void
*
p_gamma_grid
,
const
void
*
p_beta_grid
,
void
*
p_e_grid
,
void
*
p_h_grid
,
index_t
MRaw
,
index_t
NRaw
,
...
...
@@ -516,7 +515,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_
e
_grid
)},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_
h
_grid
)},
p_workspace_mean_
{
nullptr
},
p_workspace_var_
{
nullptr
},
p_workspace_count_
{
nullptr
},
...
...
@@ -938,7 +937,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
const
void
*
p_gamma
,
const
void
*
p_beta
,
void
*
p_e
,
void
*
p_h
,
index_t
MRaw
,
index_t
NRaw
,
...
...
@@ -958,7 +956,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_ds
,
p_gamma
,
p_beta
,
p_e
,
p_h
,
MRaw
,
NRaw
,
...
...
@@ -982,7 +979,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
const
void
*
p_gamma
,
const
void
*
p_beta
,
void
*
p_e
,
void
*
p_h
,
index_t
MRaw
,
index_t
NRaw
,
...
...
@@ -1002,7 +998,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_ds
,
p_gamma
,
p_beta
,
p_e
,
p_h
,
MRaw
,
NRaw
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment