Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bac7df8f
Unverified
Commit
bac7df8f
authored
Aug 17, 2022
by
Chao Liu
Committed by
GitHub
Aug 17, 2022
Browse files
use scale (#363)
parent
c961ce92
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
18 additions
and
10 deletions
+18
-10
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
+1
-0
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
+9
-5
example/32_batched_gemm_softmax_gemm/CMakeLists.txt
example/32_batched_gemm_softmax_gemm/CMakeLists.txt
+0
-2
example/CMakeLists.txt
example/CMakeLists.txt
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+7
-2
No files found.
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
0 → 100644
View file @
bac7df8f
add_example_executable
(
example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
)
example/32_batched_gemm_softmax_gemm/batched_gemm_softmax_gemm_xdl_fp16.cpp → example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
View file @
bac7df8f
View file @
bac7df8f
...
...
@@ -51,7 +51,7 @@ using CLayout = Row;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
...
...
@@ -122,7 +122,7 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
AccDataType
,
AElementOp
,
B0ElementOp
,
C
ElementOp
>
;
Acc0
ElementOp
>
;
// Ref Softmax: fp32 in, fp16 out
using
ReferenceSoftmaxInstance
=
...
...
@@ -157,6 +157,7 @@ int main(int argc, char* argv[])
ck
::
index_t
BatchStrideB0
=
-
1
;
ck
::
index_t
BatchStrideB1
=
-
1
;
ck
::
index_t
BatchStrideC
=
-
1
;
float
alpha
=
1
;
if
(
argc
==
1
)
{
...
...
@@ -181,7 +182,7 @@ int main(int argc, char* argv[])
BatchCount
=
std
::
stoi
(
argv
[
8
]);
}
else
if
(
argc
==
1
7
)
else
if
(
argc
==
1
8
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
...
...
@@ -203,6 +204,8 @@ int main(int argc, char* argv[])
BatchStrideB0
=
std
::
stoi
(
argv
[
14
]);
BatchStrideB1
=
std
::
stoi
(
argv
[
15
]);
BatchStrideC
=
std
::
stoi
(
argv
[
16
]);
alpha
=
std
::
stof
(
argv
[
17
]);
}
else
{
...
...
@@ -211,6 +214,7 @@ int main(int argc, char* argv[])
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
"BatchStrideB0, BatchStrideB1, BatchStrideC
\n
"
);
printf
(
"arg18: alpha
\n
"
);
exit
(
0
);
}
...
...
@@ -304,7 +308,7 @@ int main(int argc, char* argv[])
auto
a_element_op
=
AElementOp
{};
auto
b0_element_op
=
B0ElementOp
{};
auto
acc0_element_op
=
Acc0ElementOp
{};
auto
acc0_element_op
=
Acc0ElementOp
{
alpha
};
auto
b1_element_op
=
B1ElementOp
{};
auto
c_element_op
=
CElementOp
{};
...
...
@@ -368,7 +372,7 @@ int main(int argc, char* argv[])
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
auto
ref_gemm0_invoker
=
ref_gemm0
.
MakeInvoker
();
auto
ref_gemm0_argument
=
ref_gemm0
.
MakeArgument
(
a_g_m_k
,
b0_g_k_n
,
acc0_g_m_n
,
a_element_op
,
b0_element_op
,
PassThrough
{}
);
a_g_m_k
,
b0_g_k_n
,
acc0_g_m_n
,
a_element_op
,
b0_element_op
,
acc0_element_op
);
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
...
...
example/32_batched_gemm_softmax_gemm/CMakeLists.txt
deleted
100644 → 0
View file @
c961ce92
# TODO: add example batched_gemm_gemm_xdl_fp16
add_example_executable
(
example_batched_gemm_softmax_gemm_xdl_fp16 batched_gemm_softmax_gemm_xdl_fp16.cpp
)
example/CMakeLists.txt
View file @
bac7df8f
...
...
@@ -46,7 +46,7 @@ add_subdirectory(28_grouped_gemm_bias_e_permute)
add_subdirectory
(
29_batched_gemm_bias_e_permute
)
add_subdirectory
(
30_grouped_convnd_fwd_bias_relu_add
)
add_subdirectory
(
31_batched_gemm_gemm
)
add_subdirectory
(
32_batched_gemm_softmax_gemm
)
add_subdirectory
(
32_batched_gemm_
scale_
softmax_gemm
)
add_subdirectory
(
33_multiple_reduce
)
add_subdirectory
(
34_batchnorm
)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
bac7df8f
...
...
@@ -561,11 +561,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatAB
,
decltype
(
acc_thread_desc_k0_m_k1
),
decltype
(
a1_thread_desc_k0_m_k1
),
decltype
(
acc_element_op
)
,
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
A1ThreadSliceK0
,
A1ThreadSliceM
,
A1ThreadSliceK1
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
n4
>
{
acc_element_op
};
n4
>
{
tensor_operation
::
element_wise
::
PassThrough
{}
};
// B1 matrix blockwise copy
auto
b1_blockwise_copy
=
...
...
@@ -717,6 +717,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
blockwise_gemm
,
acc_thread_buf
,
num_k_block_main_loop
);
// Acc0 elementwise Op
static_for
<
0
,
acc_thread_buf
.
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
acc_element_op
(
acc_thread_buf
(
i
),
acc_thread_buf
[
i
]);
});
// softmax
SoftmaxBuf
&
max
=
blockwise_softmax
.
max_value_buf
;
SoftmaxBuf
&
sum
=
blockwise_softmax
.
sum_value_buf
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment