Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
980ed33a
"driver/vscode:/vscode.git/clone" did not exist on "dab2938937507f8bbdb2d058e4f989ed7094eac1"
Commit
980ed33a
authored
May 25, 2022
by
rocking
Browse files
Refine deviceop
parent
bb314592
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
62 additions
and
63 deletions
+62
-63
include/ck/tensor_operation/gpu/device/device_normalize_xdl_cshuffle.hpp
...or_operation/gpu/device/device_normalize_xdl_cshuffle.hpp
+62
-63
No files found.
include/ck/tensor_operation/gpu/device/device_normalize_xdl_cshuffle.hpp
View file @
980ed33a
...
...
@@ -22,58 +22,58 @@ template <typename XDataType,
typename
OutDataType
,
typename
ComputeDataType
,
typename
OutElementwiseFunctor
,
index_t
Dim
,
index_t
M
0
PerThread
,
index_t
XScalarPerVector
=
M0PerThread
,
index_t
MeanScalarPerVector
=
M0PerThread
,
index_t
MeanSquareScalarPerVector
=
M0PerThread
,
index_t
GammaScalarPerVector
=
M0PerThread
,
index_t
BetaScalarPerVector
=
M0PerThread
>
index_t
N
Dim
,
index_t
MPerThread
,
index_t
XScalarPerVector
,
index_t
MeanScalarPerVector
,
index_t
MeanSquareScalarPerVector
,
index_t
GammaScalarPerVector
,
index_t
BetaScalarPerVector
>
struct
DeviceNormalize_Xdl_CShuffle
:
public
BaseOperator
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
template
<
typename
Desc_M
0
>
static
auto
PadDescriptor_M
0
_1d
(
Desc_M
0
desc_m
0
,
index_t
gridSize
,
index_t
blockSize
)
template
<
typename
Desc_M
>
static
auto
PadDescriptor_M_1d
(
Desc_M
desc_m
,
index_t
gridSize
,
index_t
blockSize
)
{
const
auto
m
0
=
desc_m
0
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
blockSize
*
M
0
PerThread
;
const
auto
pad
=
math
::
integer_least_multiple
(
m
0
,
loop_step
)
-
m
0
;
const
auto
desc_m
0
_pad
=
transform_tensor_descriptor
(
desc_m
0
,
make_tuple
(
make_right_pad_transform
(
m
0
,
pad
)),
const
auto
m
=
desc_m
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
blockSize
*
MPerThread
;
const
auto
pad
=
math
::
integer_least_multiple
(
m
,
loop_step
)
-
m
;
const
auto
desc_m_pad
=
transform_tensor_descriptor
(
desc_m
,
make_tuple
(
make_right_pad_transform
(
m
,
pad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
desc_m
0
_pad
;
return
desc_m_pad
;
}
static
auto
MakeDescriptor_M
0
(
const
std
::
vector
<
index_t
>&
shape
,
static
auto
MakeDescriptor_M
(
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
stride
,
index_t
gridSize
,
index_t
blockSize
)
{
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
shape
[
I
];
},
Number
<
Dim
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
stride
[
I
];
},
Number
<
Dim
>
{});
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
lengths
[
I
];
},
Number
<
N
Dim
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
stride
[
I
];
},
Number
<
N
Dim
>
{});
// nd desc - [s0, s1, s2, ...]
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleOfShape
,
tupleOfStride
);
// merge nd to 1d desc - [s0 * s1 * ...]
if
constexpr
(
Dim
>
1
)
if
constexpr
(
N
Dim
>
1
)
{
const
auto
desc_m
0
=
transform_tensor_descriptor
(
const
auto
desc_m
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
tupleOfShape
)),
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
Dim
>
{})),
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
N
Dim
>
{})),
make_tuple
(
Sequence
<
0
>
{}));
return
PadDescriptor_M
0
_1d
(
desc_m
0
,
gridSize
,
blockSize
);
return
PadDescriptor_M_1d
(
desc_m
,
gridSize
,
blockSize
);
}
else
return
PadDescriptor_M
0
_1d
(
desc
,
gridSize
,
blockSize
);
return
PadDescriptor_M_1d
(
desc
,
gridSize
,
blockSize
);
}
using
GridDesc_M
0
=
decltype
(
MakeDescriptor_M
0
({
1
,
1
},
{
1
,
1
},
1
,
1
));
using
GridDesc_M
=
decltype
(
MakeDescriptor_M
({
1
,
1
},
{
1
,
1
},
1
,
1
));
struct
Argument
:
public
BaseArgument
{
...
...
@@ -83,7 +83,7 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
OutDataType
*
p_output
,
const
std
::
vector
<
index_t
>&
shape
,
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
stride_x
,
const
std
::
vector
<
index_t
>&
stride_mean
,
const
std
::
vector
<
index_t
>&
stride_mean_square
,
...
...
@@ -97,7 +97,7 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
p_gamma_
(
p_gamma
),
p_beta_
(
p_beta
),
p_output_
(
p_output
),
shape_
(
shape
),
lengths_
(
lengths
),
stride_x_
(
stride_x
),
stride_mean_
(
stride_mean
),
stride_mean_square_
(
stride_mean_square
),
...
...
@@ -107,13 +107,13 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
blockSize_
(
256
),
gridSize_
(
120
)
// FIXME - Calculate the grid size by number of CU in the future
{
x_grid_desc_m
0
_
=
MakeDescriptor_M
0
(
shape
,
stride_x
,
gridSize_
,
blockSize_
);
mean_grid_desc_m
0
_
=
MakeDescriptor_M
0
(
shape
,
stride_mean
,
gridSize_
,
blockSize_
);
mean_square_grid_desc_m
0
_
=
MakeDescriptor_M
0
(
shape
,
stride_mean_square
,
gridSize_
,
blockSize_
);
gamma_grid_desc_m
0
_
=
MakeDescriptor_M
0
(
shape
,
stride_gamma
,
gridSize_
,
blockSize_
);
beta_grid_desc_m
0
_
=
MakeDescriptor_M
0
(
shape
,
stride_beta
,
gridSize_
,
blockSize_
);
output_grid_desc_m
0
_
=
MakeDescriptor_M
0
(
shape
,
stride_output
,
gridSize_
,
blockSize_
);
x_grid_desc_m_
=
MakeDescriptor_M
(
lengths
,
stride_x
,
gridSize_
,
blockSize_
);
mean_grid_desc_m_
=
MakeDescriptor_M
(
lengths
,
stride_mean
,
gridSize_
,
blockSize_
);
mean_square_grid_desc_m_
=
MakeDescriptor_M
(
lengths
,
stride_mean_square
,
gridSize_
,
blockSize_
);
gamma_grid_desc_m_
=
MakeDescriptor_M
(
lengths
,
stride_gamma
,
gridSize_
,
blockSize_
);
beta_grid_desc_m_
=
MakeDescriptor_M
(
lengths
,
stride_beta
,
gridSize_
,
blockSize_
);
output_grid_desc_m_
=
MakeDescriptor_M
(
lengths
,
stride_output
,
gridSize_
,
blockSize_
);
}
const
XDataType
*
p_x_
;
...
...
@@ -122,13 +122,13 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
const
GammaDataType
*
p_gamma_
;
const
BetaDataType
*
p_beta_
;
OutDataType
*
p_output_
;
std
::
vector
<
index_t
>
shape
_
;
GridDesc_M
0
x_grid_desc_m
0
_
;
GridDesc_M
0
mean_grid_desc_m
0
_
;
GridDesc_M
0
mean_square_grid_desc_m
0
_
;
GridDesc_M
0
gamma_grid_desc_m
0
_
;
GridDesc_M
0
beta_grid_desc_m
0
_
;
GridDesc_M
0
output_grid_desc_m
0
_
;
std
::
vector
<
index_t
>
lengths
_
;
GridDesc_M
x_grid_desc_m_
;
GridDesc_M
mean_grid_desc_m_
;
GridDesc_M
mean_square_grid_desc_m_
;
GridDesc_M
gamma_grid_desc_m_
;
GridDesc_M
beta_grid_desc_m_
;
GridDesc_M
output_grid_desc_m_
;
std
::
vector
<
index_t
>
stride_x_
;
std
::
vector
<
index_t
>
stride_mean_
;
std
::
vector
<
index_t
>
stride_mean_square_
;
...
...
@@ -157,18 +157,6 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
}
};
bool
IsScalarPerVectorValid
(
bool
broadcastOnFastest
,
int
scalarPerVector
)
{
bool
ret
=
true
;
if
(
broadcastOnFastest
)
ret
=
scalarPerVector
==
1
;
else
ret
=
M0PerThread
%
scalarPerVector
==
0
;
return
ret
;
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
...
...
@@ -176,26 +164,37 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
if
(
pArg
==
nullptr
)
return
false
;
if
(
pArg
->
shape
_
.
size
()
!=
Dim
)
if
(
pArg
->
lengths
_
.
size
()
!=
N
Dim
)
return
false
;
if
(
pArg
->
shape
_
.
back
()
%
M
0
PerThread
!=
0
)
if
(
pArg
->
lengths
_
.
back
()
%
MPerThread
!=
0
)
return
false
;
if
(
!
IsScalarPerVectorValid
(
pArg
->
stride_x_
.
back
()
==
0
,
XScalarPerVector
))
auto
IsScalarPerVectorValid
=
[](
bool
isLastDimensionCoalesced
,
int
scalarPerVector
)
{
bool
ret
=
true
;
if
(
!
isLastDimensionCoalesced
)
ret
=
scalarPerVector
==
1
;
else
ret
=
MPerThread
%
scalarPerVector
==
0
;
return
ret
;
};
if
(
!
IsScalarPerVectorValid
(
pArg
->
stride_x_
.
back
()
==
1
,
XScalarPerVector
))
return
false
;
if
(
!
IsScalarPerVectorValid
(
pArg
->
stride_mean_
.
back
()
==
0
,
MeanScalarPerVector
))
if
(
!
IsScalarPerVectorValid
(
pArg
->
stride_mean_
.
back
()
==
1
,
MeanScalarPerVector
))
return
false
;
if
(
!
IsScalarPerVectorValid
(
pArg
->
stride_mean_square_
.
back
()
==
0
,
if
(
!
IsScalarPerVectorValid
(
pArg
->
stride_mean_square_
.
back
()
==
1
,
MeanSquareScalarPerVector
))
return
false
;
if
(
!
IsScalarPerVectorValid
(
pArg
->
stride_gamma_
.
back
()
==
0
,
GammaScalarPerVector
))
if
(
!
IsScalarPerVectorValid
(
pArg
->
stride_gamma_
.
back
()
==
1
,
GammaScalarPerVector
))
return
false
;
if
(
!
IsScalarPerVectorValid
(
pArg
->
stride_beta_
.
back
()
==
0
,
BetaScalarPerVector
))
if
(
!
IsScalarPerVectorValid
(
pArg
->
stride_beta_
.
back
()
==
1
,
BetaScalarPerVector
))
return
false
;
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment