Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ee6cd44d
Commit
ee6cd44d
authored
Nov 20, 2023
by
rocking
Browse files
implement generic kernel
parent
28d87372
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
400 additions
and
2 deletions
+400
-2
example/53_layernorm2d_bwd/layernorm2d_bwd_fp16.cpp
example/53_layernorm2d_bwd/layernorm2d_bwd_fp16.cpp
+3
-0
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_x_impl.hpp
...ation/gpu/device/impl/device_normalization_bwd_x_impl.hpp
+3
-1
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_x.hpp
...n/gpu/grid/normalization/gridwise_normalization_bwd_x.hpp
+394
-1
No files found.
example/53_layernorm2d_bwd/layernorm2d_bwd_fp16.cpp
View file @
ee6cd44d
...
@@ -132,6 +132,7 @@ int main()
...
@@ -132,6 +132,7 @@ int main()
dy_dev
.
ToDevice
(
dy
.
mData
.
data
());
dy_dev
.
ToDevice
(
dy
.
mData
.
data
());
x_dev
.
ToDevice
(
x
.
mData
.
data
());
x_dev
.
ToDevice
(
x
.
mData
.
data
());
gamma_dev
.
ToDevice
(
gamma
.
mData
.
data
());
mean_dev
.
ToDevice
(
mean
.
mData
.
data
());
mean_dev
.
ToDevice
(
mean
.
mData
.
data
());
inv_std_dev
.
ToDevice
(
inv_std
.
mData
.
data
());
inv_std_dev
.
ToDevice
(
inv_std
.
mData
.
data
());
...
@@ -213,9 +214,11 @@ int main()
...
@@ -213,9 +214,11 @@ int main()
dgamma_dev
.
FromDevice
(
dgamma
.
mData
.
data
());
dgamma_dev
.
FromDevice
(
dgamma
.
mData
.
data
());
dbeta_dev
.
FromDevice
(
dbeta
.
mData
.
data
());
dbeta_dev
.
FromDevice
(
dbeta
.
mData
.
data
());
dx_dev
.
FromDevice
(
dx
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
dgamma
,
host_dgamma
,
"Error: Incorrect dgamma"
,
1e-3
,
1e-3
);
pass
&=
ck
::
utils
::
check_err
(
dgamma
,
host_dgamma
,
"Error: Incorrect dgamma"
,
1e-3
,
1e-3
);
pass
&=
ck
::
utils
::
check_err
(
dbeta
,
host_dbeta
,
"Error: Incorrect dbeta"
,
1e-3
,
1e-3
);
pass
&=
ck
::
utils
::
check_err
(
dbeta
,
host_dbeta
,
"Error: Incorrect dbeta"
,
1e-3
,
1e-3
);
pass
&=
ck
::
utils
::
check_err
(
dx
,
host_dx
,
"Error: Incorrect dx"
,
1e-3
,
1e-3
);
}
}
return
(
pass
?
0
:
1
);
return
(
pass
?
0
:
1
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_x_impl.hpp
View file @
ee6cd44d
...
@@ -265,7 +265,9 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
...
@@ -265,7 +265,9 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
Make2dDescriptor
(
lengths_
,
invStdStrides_
,
numBlockTileIteration_
);
Make2dDescriptor
(
lengths_
,
invStdStrides_
,
numBlockTileIteration_
);
dx_grid_desc_m_k_
=
Make2dDescriptor
(
lengths_
,
dxStrides_
,
numBlockTileIteration_
);
dx_grid_desc_m_k_
=
Make2dDescriptor
(
lengths_
,
dxStrides_
,
numBlockTileIteration_
);
isSweeponce_
=
dy_grid_desc_m_k_
.
GetLength
(
Number
<
1
>
{})
<=
K_BlockTileSize
;
// TODO - sweep once for small k
// isSweeponce_ = dy_grid_desc_m_k_.GetLength(Number<1>{}) <= K_BlockTileSize;
isSweeponce_
=
false
;
}
}
const
DYDataType
*
p_dy_
;
const
DYDataType
*
p_dy_
;
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_x.hpp
View file @
ee6cd44d
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment