Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f91dad8e
Commit
f91dad8e
authored
Aug 22, 2022
by
Anthony Chang
Browse files
scaling
parent
a7e00533
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
9 deletions
+12
-9
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
...gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
+10
-7
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
...softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
+2
-2
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
View file @
f91dad8e
...
...
@@ -54,7 +54,7 @@ using CPermuteNumDims_G_M_O =
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
...
...
@@ -126,7 +126,7 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
AccDataType
,
AElementOp
,
B0ElementOp
,
C
ElementOp
>
;
Acc0
ElementOp
>
;
// Ref Softmax: fp32 in, fp16 out
using
ReferenceSoftmaxInstance
=
...
...
@@ -159,6 +159,7 @@ int main(int argc, char* argv[])
ck
::
index_t
BatchStrideA
=
-
1
;
ck
::
index_t
BatchStrideB0
=
-
1
;
ck
::
index_t
BatchStrideB1
=
-
1
;
float
alpha
=
1
;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
...
...
@@ -178,7 +179,7 @@ int main(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
9
)
else
if
(
argc
==
11
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
...
...
@@ -190,14 +191,16 @@ int main(int argc, char* argv[])
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
alpha
=
std
::
stof
(
argv
[
10
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 1
7
: M, N, K, O,
Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
"BatchStrideB0, BatchStrideB1, BatchStrideC
\n
"
);
printf
(
"arg4 to 1
1
: M, N, K, O,
G0, G1
\n
"
);
printf
(
"arg10: scale (alpha)
\n
"
);
exit
(
0
);
}
...
...
@@ -292,7 +295,7 @@ int main(int argc, char* argv[])
auto
a_element_op
=
AElementOp
{};
auto
b0_element_op
=
B0ElementOp
{};
auto
acc0_element_op
=
Acc0ElementOp
{};
auto
acc0_element_op
=
Acc0ElementOp
{
alpha
};
auto
b1_element_op
=
B1ElementOp
{};
auto
c_element_op
=
CElementOp
{};
...
...
@@ -359,7 +362,7 @@ int main(int argc, char* argv[])
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
auto
ref_gemm0_invoker
=
ref_gemm0
.
MakeInvoker
();
auto
ref_gemm0_argument
=
ref_gemm0
.
MakeArgument
(
a_g_m_k
,
b0_g_k_n
,
acc0_g_m_n
,
a_element_op
,
b0_element_op
,
PassThrough
{}
);
a_g_m_k
,
b0_g_k_n
,
acc0_g_m_n
,
a_element_op
,
b0_element_op
,
acc0_element_op
);
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
View file @
f91dad8e
...
...
@@ -212,9 +212,9 @@ int main(int argc, char* argv[])
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 1
7
: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
printf
(
"arg4 to 1
6
: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
"BatchStrideB0, BatchStrideB1, BatchStrideC
\n
"
);
printf
(
"arg1
8
: alpha
\n
"
);
printf
(
"arg1
7
:
scale (
alpha
)
\n
"
);
exit
(
0
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment