gaoqiong / composable_kernel · Commits · f45cc232

Commit f45cc232, authored Jul 11, 2023 by ltqin
first change bias load
Parent: 87f2bbcf

Showing 2 changed files with 50 additions and 67 deletions (+50 -67):

- example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp (+1 -1)
- include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp (+49 -66)
example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp (+1 -1)
@@ -189,7 +189,7 @@ int main(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = false;
+    bool time_kernel     = true;
 
     // GEMM shape
     ck::index_t M = 1024;
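The example's only change flips time_kernel from false to true, so the kernel run is now timed by default in addition to being verified. As a rough, hedged illustration of what such a flag typically gates (a hypothetical host-side helper, not the example's own launcher), a wrapper could look like this:

#include <chrono>
#include <cstdio>
#include <functional>

// Hypothetical sketch: run `launch` once for verification, then, when
// time_kernel is set, repeat it and report the average wall-clock time.
double run_and_maybe_time(const std::function<void()>& launch, bool time_kernel, int nrepeat = 10)
{
    launch(); // warm-up / verification run

    if(!time_kernel)
        return 0.0;

    const auto t0 = std::chrono::steady_clock::now();
    for(int i = 0; i < nrepeat; ++i)
        launch();
    const auto t1 = std::chrono::steady_clock::now();

    const double avg_ms = std::chrono::duration<double, std::milli>(t1 - t0).count() / nrepeat;
    std::printf("avg time: %f ms\n", avg_ms);
    return avg_ms;
}

int main()
{
    const bool time_kernel = true; // mirrors the example's new default
    run_and_maybe_time([] { /* kernel launch would go here */ }, time_kernel);
}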
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp (+49 -66)
@@ -79,7 +79,8 @@ template <typename A0B0B1DataType, // FIXME: don't assume A0/B0/B1 have same dat
           index_t C1ShuffleGemm0NXdlPerWavePerShuffle,
           typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          int D0sTransferSrcScalarPerVector = 4>
 struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
 {
     static_assert(LoopSched == LoopScheduler::Default,
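The header change adds a trailing non-type template parameter with a default value (D0sTransferSrcScalarPerVector = 4), so existing instantiations of the gridwise op keep compiling unchanged while callers may override the D0 source load width. A minimal standalone sketch of that pattern, using toy names rather than the CK template:

#include <cstdio>

// Toy stand-in for a gridwise kernel template: the trailing parameter has a
// default, so older instantiation lists that stop before it still compile.
template <int MPerBlock, int NPerBlock, int D0sTransferSrcScalarPerVector = 4>
struct ToyGridwiseOp
{
    static constexpr int d0_vector_width = D0sTransferSrcScalarPerVector;
};

int main()
{
    // Existing-style instantiation: picks up the default of 4.
    using OpDefault = ToyGridwiseOp<128, 128>;
    // New-style instantiation: overrides the D0 load vector width.
    using OpWide = ToyGridwiseOp<128, 128, 8>;

    std::printf("default width = %d, overridden width = %d\n",
                OpDefault::d0_vector_width, OpWide::d0_vector_width);
}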
@@ -710,13 +711,13 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
         constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
             make_naive_tensor_descriptor_packed(make_tuple(I1,   // MBlockId
                                                            I1,   // NBlockID
-                                                           I1,   // MRepeat
-                                                           I1,   // NRepeat
-                                                           I1,   // MWaveId
-                                                           I1,   // NWaveId
-                                                           I1,   // MPerXdl
-                                                           I1,   // NGroupNum
-                                                           I1,   // NInputNum
+                                                           m0,   // MRepeat
+                                                           n0,   // NRepeat
+                                                           m1,   // MWaveId
+                                                           n1,   // NWaveId
+                                                           m2,   // MPerXdl
+                                                           n2,   // NGroupNum
+                                                           n3,   // NInputNum
                                                            n4)); // registerNum
 
         auto d0s_thread_buf = generate_tuple(
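The per-thread D0 descriptor previously described a single n4-wide vector (every other dimension pinned to 1); it now spans m0 x n0 x n2 x n4 plus the singleton block/wave dimensions, so a thread stages its whole D0 tile in registers at once instead of re-reading a small vector per step. A small sketch of the resulting buffer-size arithmetic, with illustrative extents only (the real m0/n0/n2/n4 are compile-time values derived from the blockwise GEMM, not these numbers):

#include <cstdio>

int main()
{
    // Illustrative compile-time extents; the real values come from the
    // blockwise GEMM configuration, not hard-coded constants like these.
    constexpr int m0 = 4; // MRepeat
    constexpr int n0 = 2; // NRepeat
    constexpr int n2 = 4; // NGroupNum
    constexpr int n4 = 4; // registers per group (registerNum)

    constexpr int old_elems = n4;                // old: one n4 vector at a time
    constexpr int new_elems = m0 * n0 * n2 * n4; // new: full per-thread D0 tile

    std::printf("old thread buffer: %d elements, new thread buffer: %d elements\n",
                old_elems, new_elems);
}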
@@ -732,9 +733,6 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
         const auto wave_id     = GetGemm0WaveIdx();
         const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
 
-        constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<Gemm0MXdlPerWave>{}, Number<Gemm0NXdlPerWave>{}, n2, n4));
-
         auto d0s_threadwise_copy = generate_tuple(
             [&](auto i) {
                 return ThreadwiseTensorSliceTransfer_v2<
@@ -742,10 +740,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
                     A0B0B1DataType,
                     decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
                     decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-                    Sequence<I1, I1, I1, I1, I1, I1, I1, I1, I1, n4>,
+                    Sequence<I1, // MBlockId
+                             I1, // NBlockID
+                             m0, // MRepeat
+                             n0, // NRepeat
+                             m1, // MWaveId
+                             n1, // NWaveId
+                             m2, // MPerXdl
+                             n2, // NGroupNum
+                             n3, // NInputNum
+                             n4>,
                     Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                     9,
-                    n4,
+                    D0sTransferSrcScalarPerVector,
                     1,
                     false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
                            make_multi_index(block_work_idx[I0], // MBlockId
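With the larger slice, the threadwise copy's slice lengths now name every dimension of the per-thread D0 tile, and the source read width becomes the new D0sTransferSrcScalarPerVector parameter instead of n4. The idea behind a "scalar per vector" knob is simply to read the contiguous, fastest-varying dimension in fixed-size chunks; a plain C++ sketch of that access pattern (not the CK transfer class):

#include <array>
#include <cstdio>

// Copy `count` contiguous elements from src to dst in chunks of V elements.
// V plays the role of a ScalarPerVector parameter: it must divide the length
// of the contiguous dimension being read.
template <int V>
void chunked_copy(const float* src, float* dst, int count)
{
    static_assert(V > 0, "vector width must be positive");
    for(int i = 0; i < count; i += V)
        for(int j = 0; j < V; ++j) // a real transfer would issue one V-wide load here
            dst[i + j] = src[i + j];
}

int main()
{
    std::array<float, 16> src{}, dst{};
    for(int i = 0; i < 16; ++i)
        src[i] = static_cast<float>(i);

    chunked_copy<4>(src.data(), dst.data(), 16); // e.g. a vector width of 4

    std::printf("dst[15] = %f\n", dst[15]);
}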
@@ -898,66 +905,42 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
                                                                blockwise_gemm0,
                                                                acc0_thread_buf,
                                                                num_k_block_main_loop);
 
-        // bias+gelu
-        {
-            static_for<0, Gemm0MXdlPerWave, 1>{}([&](auto mr) {
-                static_for<0, Gemm0NXdlPerWave, 1>{}([&](auto nr) {
-                    static_for<0, n2, 1>{}([&](auto groupid) {
-                        static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                            d0s_threadwise_copy(i).Run(
-                                d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                                d0s_grid_buf[i],
-                                d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                                d0s_thread_buf(i));
-                        });
-
-                        static_for<0, n4, 1>{}([&](auto i) {
-                            constexpr index_t c_offset =
-                                acc0_thread_desc.CalculateOffset(make_tuple(mr, nr, groupid, i));
-                            // get reference to src data
-                            const auto src_data_refs = generate_tie(
-                                // return type should be lvalue
-                                [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; },
-                                Number<NumD0Tensor>{});
-
-                            // get reference to dst data
-                            auto dst_data_refs = generate_tie(
-                                // return type should be lvalue
-                                [&](auto) -> auto& { return acc0_thread_buf(Number<c_offset>{}); },
-                                Number<2>{});
-
-                            unpack2(cde0_element_op, dst_data_refs, src_data_refs);
-                        });
-
-                        static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                            d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                                d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                                make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0));
-                        });
-                    });
-
-                    static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                        d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                            d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                            make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0));
-                    });
-                });
-
-                static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                    d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                        d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                        make_multi_index(0, 0, 1, -Gemm0NXdlPerWave, 0, 0, 0, 0, 0, 0));
-                });
-            });
-
-            static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                    d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                    make_multi_index(0, 1, -Gemm0MXdlPerWave, 0, 0, 0, 0, 0, 0, 0));
-            });
-        }
+        // multiple d
+        if constexpr(NumD0Tensor)
+        {
+            static_for<0, NumD0Tensor, 1>{}([&](auto i) {
+                d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
+                                           d0s_grid_buf[i],
+                                           d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                                           make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                                           d0s_thread_buf(i));
+            });
+
+            static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
+                // get reference to src data
+                const auto src_data_refs = generate_tie(
+                    // return type should be lvalue
+                    [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; },
+                    Number<NumD0Tensor>{});
+
+                // get reference to dst data
+                auto dst_data_refs = generate_tie(
+                    // return type should be lvalue
+                    [&](auto) -> auto& { return acc0_thread_buf(i); },
+                    Number<2>{});
+
+                unpack2(cde0_element_op, dst_data_refs, src_data_refs);
+            });
+
+            static_for<0, NumD0Tensor, 1>{}([&](auto i) {
+                d0s_threadwise_copy(i).MoveSrcSliceWindow(
+                    d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
+                    make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+            });
+        }
+        else
+        {
+            static_for<0, acc0_thread_buf.Size(), 1>{}(
+                [&](auto i) { cde0_element_op(acc_thread_buf(i), acc0_thread_buf[i]); });
+        }
 
         // gemm1
         {
             // TODO: explore using dynamic buffer for a1 thread buffer
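The restructured epilogue replaces the nested per-(MRepeat, NRepeat, group) walk, which re-ran the threadwise copy, recomputed an offset via acc0_thread_desc, and moved the source window after every small step, with one bulk copy of the whole per-thread D0 tile, a single flat element-wise loop of length m0 * n0 * n2 * n4, and one window move per block, plus an else branch for the case with no D0 tensors. A simplified standalone sketch of that control-flow change, using plain arrays and a lambda instead of CK's static_for and thread buffers (illustrative extents only):

#include <cstdio>
#include <vector>

int main()
{
    // Illustrative extents (see the descriptor sketch above); the real values
    // are compile-time constants taken from the blockwise GEMM.
    constexpr int m0 = 4, n0 = 2, n2 = 4, n4 = 4;
    constexpr int tile = m0 * n0 * n2 * n4;

    std::vector<float> acc(tile, 1.0f); // stand-in for acc0_thread_buf
    std::vector<float> d0(tile, 0.5f);  // stand-in for the staged D0 thread tile

    // Stand-in for cde0_element_op (e.g. an "add bias" D-tensor op).
    auto cde0_element_op = [](float& dst, float acc_v, float d0_v) { dst = acc_v + d0_v; };

    // Old shape: nested loops over (MRepeat, NRepeat, group), each iteration
    // issuing its own small copy and offset computation before the op.
    // New shape: the whole tile is staged first (modelled here by `d0` already
    // being filled), then one flat loop applies the element-wise op.
    for(int i = 0; i < tile; ++i)
        cde0_element_op(acc[i], acc[i], d0[i]);

    std::printf("acc[0] = %f (1.0 + 0.5)\n", acc[0]);
}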