Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c13776be
Commit
c13776be
authored
Dec 05, 2022
by
rocking
Browse files
1. Allocate mean, var and count into by SetWorkSpacePointer.
2. Add GetWorkSpaceSize to calculate the space size
parent
5215f11d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
83 additions
and
38 deletions
+83
-38
example/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
...ple/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
+4
-0
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+79
-38
No files found.
example/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
View file @
c13776be
...
@@ -175,6 +175,10 @@ int main()
...
@@ -175,6 +175,10 @@ int main()
throw
std
::
runtime_error
(
"wrong! this device_op instance does not support this problem"
);
throw
std
::
runtime_error
(
"wrong! this device_op instance does not support this problem"
);
}
}
size_t
workspace_sz
=
device_op
.
GetWorkSpaceSize
(
&
argument
);
DeviceMem
workspace_dev
(
workspace_sz
);
device_op
.
SetWorkSpacePointer
(
&
argument
,
workspace_dev
.
GetDeviceBuffer
());
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
if
(
do_verification
)
if
(
do_verification
)
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
c13776be
...
@@ -481,9 +481,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -481,9 +481,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
p_w
elford_mean_grid
_
{
nullptr
},
p_w
orkspace_mean
_
{
nullptr
},
p_w
elford_var_grid
_
{
nullptr
},
p_w
orkspace_var
_
{
nullptr
},
p_w
elford
_count_
grid_
{
nullptr
},
p_w
orkspace
_count_
{
nullptr
},
p_gamma_grid_
{
static_cast
<
const
GammaDataType
*>
(
p_gamma_grid
)},
p_gamma_grid_
{
static_cast
<
const
GammaDataType
*>
(
p_gamma_grid
)},
p_beta_grid_
{
static_cast
<
const
BetaDataType
*>
(
p_beta_grid
)},
p_beta_grid_
{
static_cast
<
const
BetaDataType
*>
(
p_beta_grid
)},
p_h_grid_
{
static_cast
<
HDataType
*>
(
p_h_grid
)},
p_h_grid_
{
static_cast
<
HDataType
*>
(
p_h_grid
)},
...
@@ -510,14 +510,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -510,14 +510,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
mean_var_count_grid_desc_m_nblock_
=
mean_var_count_grid_desc_m_nblock_
=
DeviceOp
::
MakeMeanVarCountGridDescriptor_M_NBlock
(
MRaw
,
gemm_nblock_
);
DeviceOp
::
MakeMeanVarCountGridDescriptor_M_NBlock
(
MRaw
,
gemm_nblock_
);
// TODO - GetWorkSpaceSize(), let user hipMalloc the memory
int
gemm_welford_size
=
MRaw
*
gemm_nblock_
;
hip_check_error
(
hipMalloc
(
&
p_welford_mean_grid_
,
sizeof
(
MeanDataType
)
*
gemm_welford_size
));
hip_check_error
(
hipMalloc
(
&
p_welford_var_grid_
,
sizeof
(
VarDataType
)
*
gemm_welford_size
));
hip_check_error
(
hipMalloc
(
&
p_welford_count_grid_
,
sizeof
(
int32_t
)
*
gemm_welford_size
));
// populate pointer, desc for Ds
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
...
@@ -568,9 +560,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -568,9 +560,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const
BDataType
*
p_b_grid_
;
const
BDataType
*
p_b_grid_
;
typename
GridwiseGemmWelford
::
DsGridPointer
p_ds_grid_
;
typename
GridwiseGemmWelford
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
EDataType
*
p_e_grid_
;
MeanDataType
*
p_welford_mean_grid
_
;
void
*
p_workspace_mean
_
;
VarDataType
*
p_welford_var_grid
_
;
void
*
p_workspace_var
_
;
int32_t
*
p_welford
_count_
grid_
;
void
*
p_workspace
_count_
;
const
GammaDataType
*
p_gamma_grid_
;
const
GammaDataType
*
p_gamma_grid_
;
const
BetaDataType
*
p_beta_grid_
;
const
BetaDataType
*
p_beta_grid_
;
HDataType
*
p_h_grid_
;
HDataType
*
p_h_grid_
;
...
@@ -682,9 +674,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -682,9 +674,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
p_e_grid_
,
arg
.
p_welford
_mean_
grid_
,
static_cast
<
MeanDataType
*>
(
arg
.
p_workspace
_mean_
)
,
arg
.
p_welford_var_grid_
,
static_cast
<
VarDataType
*>
(
arg
.
p_workspace_var_
)
,
arg
.
p_welford
_count_
grid_
,
static_cast
<
int32_t
*>
(
arg
.
p_workspace
_count_
)
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
cde_element_op_
,
...
@@ -703,15 +695,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -703,15 +695,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t
numNormBlockTileIteration_N
=
index_t
numNormBlockTileIteration_N
=
math
::
integer_divide_ceil
(
N
,
LayernormBlockTileSize_M_N
::
At
(
I1
));
math
::
integer_divide_ceil
(
N
,
LayernormBlockTileSize_M_N
::
At
(
I1
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_welford_layernorm
,
kernel_welford_layernorm
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
arg
.
p_e_grid_
,
arg
.
p_e_grid_
,
arg
.
p_welford
_mean_
grid_
,
static_cast
<
const
MeanDataType
*>
(
arg
.
p_workspace
_mean_
)
,
arg
.
p_welford_var_grid_
,
static_cast
<
const
VarDataType
*>
(
arg
.
p_workspace_var_
)
,
arg
.
p_welford
_count_
grid_
,
static_cast
<
const
int32_t
*>
(
arg
.
p_workspace
_count_
)
,
arg
.
p_gamma_grid_
,
arg
.
p_gamma_grid_
,
arg
.
p_beta_grid_
,
arg
.
p_beta_grid_
,
arg
.
p_h_grid_
,
arg
.
p_h_grid_
,
...
@@ -746,6 +739,54 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -746,6 +739,54 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
}
}
};
};
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
pArg
)
const
override
{
const
Argument
*
pArg_
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
size_t
workspace_size
=
0
;
int
gemm_welford_size
=
pArg_
->
mean_var_count_grid_desc_m_nblock_
.
GetElementSpaceSize
();
// workspace for welford intermediate mean
workspace_size
+=
gemm_welford_size
*
sizeof
(
MeanDataType
)
+
64
;
// workspace for welford intermediate mean
workspace_size
+=
gemm_welford_size
*
sizeof
(
VarDataType
)
+
64
;
// workspace for welford intermediate count
workspace_size
+=
gemm_welford_size
*
sizeof
(
int32_t
)
+
64
;
return
(
workspace_size
);
};
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
)
const
override
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
pArg_
->
p_workspace_
=
p_workspace
;
int
gemm_welford_size
=
pArg_
->
mean_var_count_grid_desc_m_nblock_
.
GetElementSpaceSize
();
// int gemm_welford_size = MRaw * pArg->gemm_nblock_;
// setup buffer used for intermediate welford mean
pArg_
->
p_workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
index_t
mean_space_sz
=
gemm_welford_size
*
sizeof
(
MeanDataType
);
mean_space_sz
=
math
::
integer_least_multiple
(
mean_space_sz
,
64
);
// setup buffer used for intermediate welford varirance
pArg_
->
p_workspace_var_
=
reinterpret_cast
<
char
*>
(
pArg_
->
p_workspace_mean_
)
+
mean_space_sz
;
index_t
variance_space_sz
=
gemm_welford_size
*
sizeof
(
VarDataType
);
variance_space_sz
=
math
::
integer_least_multiple
(
variance_space_sz
,
64
);
// setup buffer used for intermediate welford count
pArg_
->
p_workspace_count_
=
reinterpret_cast
<
char
*>
(
pArg_
->
p_workspace_var_
)
+
variance_space_sz
;
};
static
bool
IsSupportedArgument
(
const
Argument
&
)
static
bool
IsSupportedArgument
(
const
Argument
&
)
{
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment