Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
80f038a0
Commit
80f038a0
authored
May 26, 2022
by
root
Browse files
Post-merge fixes
parent
bb1f8082
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
64 additions
and
51 deletions
+64
-51
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
+5
-4
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
..._operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
+59
-47
No files found.
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
View file @
80f038a0
...
...
@@ -17,7 +17,8 @@ using ABDataType = F16;
using
CDataType
=
F16
;
using
EltwiseComputeDataType
=
F32
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
<
EltwiseComputeDataType
,
EltwiseComputeDataType
,
EltwiseComputeDataType
>
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
...
...
@@ -48,11 +49,11 @@ void host_broadcast3D_am_bmnk(HostTensorC& C,
for
(
std
::
size_t
n
=
0
;
n
<
shape
[
1
];
++
n
)
for
(
std
::
size_t
k
=
0
;
k
<
shape
[
2
];
++
k
)
{
ComputeDataType
a_val
=
static_cas
t
<
ComputeDataType
>
(
A
(
m
));
ComputeDataType
b_val
=
static_cas
t
<
ComputeDataType
>
(
B
(
m
,
n
,
k
));
ComputeDataType
a_val
=
ck
::
type_conver
t
<
ComputeDataType
>
(
A
(
m
));
ComputeDataType
b_val
=
ck
::
type_conver
t
<
ComputeDataType
>
(
B
(
m
,
n
,
k
));
ComputeDataType
c_val
=
0
;
functor
(
c_val
,
a_val
,
b_val
);
C
(
m
,
n
,
k
)
=
static_cas
t
<
ctype
>
(
c_val
);
C
(
m
,
n
,
k
)
=
ck
::
type_conver
t
<
ctype
>
(
c_val
);
}
}
...
...
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
80f038a0
...
...
@@ -93,39 +93,37 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
ScalarPerVector
=
Number
<
4
>
{};
static
constexpr
auto
MPerThread
=
Number
<
4
>
{};
static
constexpr
auto
AScalarPerVector
=
Number
<
4
>
{};
static
constexpr
auto
BScalarPerVector
=
Number
<
4
>
{};
static
constexpr
auto
CScalarPerVector
=
Number
<
4
>
{};
template
<
typename
Desc_M
0
>
static
auto
PadDescriptor_M
0
_1d
(
Desc_M
0
desc_m
0
,
index_t
gridSize
,
index_t
blockSize
)
template
<
typename
Desc_M
>
static
auto
PadDescriptor_M_1d
(
Desc_M
desc_m
,
index_t
gridSize
,
index_t
blockSize
)
{
const
auto
m0
=
desc_m
0
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
blockSize
*
ScalarPerVector
;
const
auto
pad
=
math
::
integer_least_multiple
(
m0
,
loop_step
)
-
m0
;
const
auto
desc_m
0
_pad
=
transform_tensor_descriptor
(
desc_m
0
,
make_tuple
(
make_right_pad_transform
(
m0
,
pad
)),
const
auto
M
=
desc_m
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
blockSize
*
MPerThread
;
const
auto
pad
=
math
::
integer_least_multiple
(
M
,
loop_step
)
-
M
;
const
auto
desc_m_pad
=
transform_tensor_descriptor
(
desc_m
,
make_tuple
(
make_right_pad_transform
(
M
,
pad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
desc_m
0
_pad
;
return
desc_m_pad
;
}
static
auto
MakeDescriptor_M
0
(
const
std
::
vector
<
in
t
>&
shape
,
const
std
::
vector
<
int
>&
stride
,
index_t
gridSize
,
index_t
blockSize
)
static
auto
MakeDescriptor_M
(
const
std
::
vector
<
in
dex_t
>&
lengths
,
const
std
::
vector
<
in
dex_
t
>&
stride
s
,
index_t
gridSize
,
index_t
blockSize
)
{
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
shape
[
I
];
},
Number
<
2
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
stride
[
I
];
},
Number
<
2
>
{});
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
lengths
[
I
];
},
Number
<
1
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
stride
s
[
I
];
},
Number
<
1
>
{});
// nd desc - [s0, s1, s2, ...]
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleOfShape
,
tupleOfStride
);
const
auto
desc_m0
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
tupleOfShape
)),
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
2
>
{})),
make_tuple
(
Sequence
<
0
>
{}));
return
PadDescriptor_M0_1d
(
desc_m0
,
gridSize
,
blockSize
);
return
PadDescriptor_M_1d
(
desc
,
gridSize
,
blockSize
);
}
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
...
...
@@ -395,7 +393,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
GridDesc_M
0
=
decltype
(
MakeDescriptor_M
0
({
1
,
1
},
{
1
,
1
},
1
,
1
));
using
C
GridDesc_M
=
decltype
(
MakeDescriptor_M
({
1
,
1
},
{
1
,
1
},
1
,
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
...
...
@@ -492,13 +490,13 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
c_grid_desc_m
0
_
=
DeviceOp
::
MakeDescriptor_M
0
({
MRaw
,
NRaw
},
{
StrideC
,
I1
},
grid_size
,
BlockSize
);
c_grid_desc_m_
=
DeviceOp
::
MakeDescriptor_M
({
MRaw
,
NRaw
},
{
StrideC
,
I1
},
grid_size
,
BlockSize
);
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
c_grid_desc_m
0
_
=
DeviceOp
::
MakeDescriptor_M
0
({
MRaw
,
NRaw
},
{
I1
,
StrideC
},
grid_size
,
BlockSize
);
c_grid_desc_m_
=
DeviceOp
::
MakeDescriptor_M
({
MRaw
,
NRaw
},
{
I1
,
StrideC
},
grid_size
,
BlockSize
);
}
p_aux_2_grid_
=
p_workspace
+
c_grid_desc_m_n_
.
GetElementSpaceSize
();
...
...
@@ -516,7 +514,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
GridDesc_M
0
c_grid_desc_m
0
_
;
C
GridDesc_M
c_grid_desc_m_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
...
@@ -556,27 +554,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType
,
CDataType
,
CDataType
,
GridDesc_M0
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
Add
,
ScalarPerVector
>
;
MPerThread
,
AScalarPerVector
,
BScalarPerVector
,
CScalarPerVector
>
;
using
GridwiseBinSubstract
=
GridwiseBinaryElementwise_1D
<
CDataType
,
CDataType
,
CDataType
,
CDataType
,
GridDesc_M0
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
Substract
,
ScalarPerVector
>
;
MPerThread
,
AScalarPerVector
,
BScalarPerVector
,
CScalarPerVector
>
;
const
auto
add_kernel
=
kernel_binary_elementwise_1d
<
GridwiseBinAdd
,
CDataType
,
CDataType
,
CDataType
,
GridDesc_M0
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
Add
>
;
const
auto
substract_kernel
=
kernel_binary_elementwise_1d
<
GridwiseBinSubstract
,
CDataType
,
CDataType
,
CDataType
,
GridDesc_M0
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
Substract
>
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
...
...
@@ -637,9 +649,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
p_aux_grid_
,
arg
.
p_aux_2_grid_
,
arg
.
p_c_grid_real_
,
arg
.
c_grid_desc_m
0
_
,
arg
.
c_grid_desc_m
0
_
,
arg
.
c_grid_desc_m
0
_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
Substract
{});
ave_time
+=
...
...
@@ -685,9 +697,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
p_aux_grid_
,
arg
.
p_aux_2_grid_
,
arg
.
p_c_grid_imag_
,
arg
.
c_grid_desc_m
0
_
,
arg
.
c_grid_desc_m
0
_
,
arg
.
c_grid_desc_m
0
_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
Add
{});
}
else
...
...
@@ -748,9 +760,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
p_aux_grid_
,
arg
.
p_aux_2_grid_
,
arg
.
p_c_grid_real_
,
arg
.
c_grid_desc_m
0
_
,
arg
.
c_grid_desc_m
0
_
,
arg
.
c_grid_desc_m
0
_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
Substract
{});
ave_time
+=
...
...
@@ -796,9 +808,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
p_aux_grid_
,
arg
.
p_aux_2_grid_
,
arg
.
p_c_grid_imag_
,
arg
.
c_grid_desc_m
0
_
,
arg
.
c_grid_desc_m
0
_
,
arg
.
c_grid_desc_m
0
_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
Add
{});
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment