Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b010b095
Commit
b010b095
authored
Jun 19, 2023
by
aska-0096
Browse files
part2 of previous commit
parent
43777959
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
439 additions
and
43 deletions
+439
-43
example/13_pool2d_fwd/pool2d_fwd_common.hpp
example/13_pool2d_fwd/pool2d_fwd_common.hpp
+3
-3
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp
...m_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
...emm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
..._scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
...gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
+340
-0
example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc
...ched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc
+0
-0
example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc
...tched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc
+0
-0
example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
...m_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+86
-30
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
...grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
+3
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
...or_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
+3
-3
No files found.
example/13_pool2d_fwd/pool2d_fwd_common.hpp
View file @
b010b095
...
...
@@ -119,9 +119,9 @@ bool pool_test(bool do_verification,
{
N
,
C
,
Hi
,
Wi
},
{
Y
,
X
},
{
N
,
C
,
Ho
,
Wo
},
{
C
*
Hi
*
Wi
,
1
,
Wi
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{},
{},
{},
window_strides
,
input_left_pads
,
input_right_pads
,
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp
View file @
b010b095
...
...
@@ -161,6 +161,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B1ElementOp
,
CElementOp
>
;
#include "run_batched_gemm_scale_softmax_gemm_permute.inc"
#include "run_batched_gemm_scale_softmax_gemm_permute
_wmma
.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
View file @
b010b095
...
...
@@ -283,6 +283,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B1ElementOp
,
CElementOp
>
;
#include "run_batched_gemm_scale_softmax_gemm_permute.inc"
#include "run_batched_gemm_scale_softmax_gemm_permute
_wmma
.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
View file @
b010b095
...
...
@@ -327,6 +327,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B1ElementOp
,
CElementOp
>
;
#include "run_cross_attention.inc"
#include "run_cross_attention
_wmma
.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
0 → 100644
View file @
b010b095
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
int
run
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck
::
index_t
M
=
120
;
ck
::
index_t
N
=
1000
;
ck
::
index_t
K
=
64
;
ck
::
index_t
O
=
128
;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
G0
=
7
;
ck
::
index_t
G1
=
13
;
float
alpha
=
1
;
bool
input_permute
=
false
;
bool
output_permute
=
true
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
13
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
alpha
=
std
::
stof
(
argv
[
10
]);
input_permute
=
std
::
stoi
(
argv
[
11
]);
output_permute
=
std
::
stoi
(
argv
[
12
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 11: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg10: scale (alpha)
\n
"
);
printf
(
"arg11 to 12: input / output permute
\n
"
);
exit
(
0
);
}
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// A layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1, M, K]
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// B0 layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1, N, K]
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// B1 layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1, N, O]
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// C layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1, M, O]
Tensor
<
ADataType
>
a_gs_ms_ks
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
Tensor
<
B0DataType
>
b0_gs_ns_ks
(
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
);
Tensor
<
B1DataType
>
b1_gs_os_ns
(
b1_gs_os_ns_lengths
,
b1_gs_os_ns_strides
);
Tensor
<
CDataType
>
c_gs_ms_os_host_result
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
Tensor
<
CDataType
>
c_gs_ms_os_device_result
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
std
::
cout
<<
"a_gs_ms_ks: "
<<
a_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b0_gs_ns_ks: "
<<
b0_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b1_gs_os_ns: "
<<
b1_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_gs_ms_os: "
<<
c_gs_ms_os_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
case
2
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
0.0
,
1.0
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
B1DataType
>
{
-
0.5
,
0.5
});
break
;
case
3
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
break
;
case
4
:
// A, B0, B1 1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
case
5
:
// Rand: b1 b0; unit: a
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
case
6
:
// Rand: a b0 ; unit: B1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
case
7
:
// Rand: a b1 ; unit: b0
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
case
8
:
// Rand: a ; unit: b0 b1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
case
9
:
// Rand: b0 ; unit: a b1
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
B1DataType
>
{});
break
;
case
10
:
// Rand: b1 ; unit: a b0
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
default
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b0_device_buf
(
sizeof
(
B0DataType
)
*
b0_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b1_device_buf
(
sizeof
(
B1DataType
)
*
b1_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_gs_ms_os_device_result
.
mDesc
.
GetElementSpaceSize
());
a_device_buf
.
ToDevice
(
a_gs_ms_ks
.
mData
.
data
());
b0_device_buf
.
ToDevice
(
b0_gs_ns_ks
.
mData
.
data
());
b1_device_buf
.
ToDevice
(
b1_gs_os_ns
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b0_element_op
=
B0ElementOp
{};
auto
acc0_element_op
=
Acc0ElementOp
{
alpha
};
auto
b1_element_op
=
B1ElementOp
{};
auto
c_element_op
=
CElementOp
{};
// do GEMM
float
best_perf
=
.0
;
float
best_time
=
.0
;
int
not_pass
=
0
;
std
::
string
best_kernel
=
""
;
printf
(
"Verification: %s
\n
"
,
do_verification
?
"ON"
:
"OFF"
);
// TODO ANT: replace array with vector?
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
DeviceMHAFactory
>
,
1
>
{}([
&
](
auto
i
)
->
void
{
const
auto
device_conv_mha_instance
=
std
::
get
<
i
>
(
DeviceMHAFactory
{});
using
DeviceMHAInstance
=
ck
::
remove_cvref_t
<
decltype
(
device_conv_mha_instance
)
>
;
auto
gemm
=
DeviceMHAInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B0DataType
*>
(
b0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B1DataType
*>
(
b1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
O
,
G0
,
G1
,
alpha
,
input_permute
,
output_permute
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
std
::
cout
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
// return 0;
}
ck
::
index_t
BatchCount
=
G0
*
G1
;
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
BatchCount
;
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
)
*
BatchCount
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
if
(
tflops
>
best_perf
)
{
best_perf
=
tflops
;
best_time
=
ave_time
*
1000
;
best_kernel
=
gemm
.
GetTypeString
();
}
if
(
do_verification
)
{
c_device_buf
.
FromDevice
(
c_gs_ms_os_device_result
.
mData
.
data
());
Tensor
<
ADataType
>
a_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
B0DataType
>
b0_g_k_n
({
BatchCount
,
K
,
N
});
Tensor
<
B1DataType
>
b1_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
Acc0DataType
>
acc0_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after gemm0
Tensor
<
ADataType
>
a1_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after softmax
Tensor
<
CDataType
>
c_g_m_o_host_result
({
BatchCount
,
M
,
O
});
// scratch object after gemm1
// permute
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
b0_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b0_g_k_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
b1_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b1_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
// gemm 0
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
auto
ref_gemm0_invoker
=
ref_gemm0
.
MakeInvoker
();
auto
ref_gemm0_argument
=
ref_gemm0
.
MakeArgument
(
a_g_m_k
,
b0_g_k_n
,
acc0_g_m_n
,
a_element_op
,
b0_element_op
,
acc0_element_op
);
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// masking
const
auto
mask
=
typename
DeviceMHAInstance
::
C0MatrixMask
(
N
);
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
// softmax
auto
ref_softmax
=
ReferenceSoftmaxInstance
{};
auto
ref_softmax_invoker
=
ref_softmax
.
MakeInvoker
();
auto
ref_softmax_argument
=
ref_softmax
.
MakeArgument
(
acc0_g_m_n
,
a1_g_m_n
,
1
,
0
,
{
2
});
ref_softmax_invoker
.
Run
(
ref_softmax_argument
);
// gemm1
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
a1_g_m_n
,
b1_g_n_o
,
c_g_m_o_host_result
,
PassThrough
{},
b1_element_op
,
c_element_op
);
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
// permute
c_gs_ms_os_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
self
(
idx
)
=
c_g_m_o_host_result
(
g
,
idx
[
2
],
idx
[
3
]);
});
// default absolute error and relative error is 0.001
double
rtol
=
1
e
-
3
;
double
atol
=
1
e
-
3
;
// when BF16 is taken, set absolute error and relative error to 0.01
if
(
std
::
is_same_v
<
ADataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
B0DataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
B1DataType
,
ck
::
bhalf_t
>
&&
std
::
is_same_v
<
CDataType
,
ck
::
bhalf_t
>
)
{
rtol
=
1
e
-
2
;
atol
=
1
e
-
2
;
}
bool
this_run_verification
=
ck
::
utils
::
check_err
(
c_gs_ms_os_device_result
.
mData
,
c_gs_ms_os_host_result
.
mData
,
"Error: Incorrect results!"
,
rtol
,
atol
);
printf
(
"Verification: %s, Pass: %s
\n
"
,
do_verification
?
"ON"
:
"OFF"
,
this_run_verification
?
"YES"
:
"NO"
);
if
(
!
this_run_verification
)
{
not_pass
=
1
;
printf
(
"%d th MHA instance verification Failed
\n
"
,
i
.
value
);
}
}
});
std
::
cout
<<
"---------------------------------------------------------------------------------"
"-----------"
<<
std
::
endl
;
std
::
cout
<<
"Problem Size: BatchCount: "
<<
G0
<<
", HeadNum: "
<<
G1
<<
", M: "
<<
M
<<
", N: "
<<
N
<<
", K: "
<<
K
<<
", O: "
<<
O
<<
std
::
endl
;
std
::
cout
<<
"---------------------------------------------------------------------------------"
"-----------"
<<
std
::
endl
;
std
::
cout
<<
"Best kernel: "
<<
best_kernel
<<
" , "
<<
best_perf
<<
" TFlops , "
<<
best_time
<<
" us"
<<
std
::
endl
;
std
::
cout
<<
"---------------------------------------------------------------------------------"
"-----------"
<<
std
::
endl
;
return
not_pass
;
}
example/32_batched_gemm_scale_softmax_gemm/run_cross_attention.inc
→
example/32_batched_gemm_scale_softmax_gemm/run_cross_attention
_wmma
.inc
View file @
b010b095
File moved
example/32_batched_gemm_scale_softmax_gemm/run_self_attention.inc
→
example/32_batched_gemm_scale_softmax_gemm/run_self_attention
_wmma
.inc
View file @
b010b095
File moved
example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
View file @
b010b095
...
...
@@ -283,6 +283,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B1ElementOp
,
CElementOp
>
;
#include "run_self_attention.inc"
#include "run_self_attention
_wmma
.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
b010b095
...
...
@@ -252,16 +252,16 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1Spec
,
CSpec
>
;
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_strides_vec
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
),
Number
<
AK1
>
{});
}
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides_vec
)
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b_gs_ns_ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b_gs_ns_ks_strides_vec
)
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b_gs_ns_ks_lengths_vec
,
b_gs_ns_ks_strides_vec
),
...
...
@@ -269,8 +269,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
static
auto
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides_vec
)
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_gemm1ns_gemm1ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_gemm1ns_gemm1ks_strides_vec
)
{
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths_vec
,
...
...
@@ -453,14 +453,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CDataType
*
p_c_grid
,
const
std
::
array
<
void
*
,
NumD0Tensor
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumD1Tensor
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_lengths
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_strides
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b_gs_ns_ks_lengths
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b_gs_ns_ks_strides
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>&
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD1Tensor
>&
...
...
@@ -835,20 +835,48 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1ElementwiseOperation
b1_element_op
,
C1DEElementwiseOperation
c1de_element_op
)
{
constexpr
auto
dimension
=
NumDimG
+
NumDimM
+
NumDimN
;
std
::
array
<
index_t
,
dimension
>
a_gs_ms_ks_lengths_
{};
std
::
array
<
index_t
,
dimension
>
a_gs_ms_ks_strides_
{};
std
::
array
<
index_t
,
dimension
>
b_gs_ns_ks_lengths_
{};
std
::
array
<
index_t
,
dimension
>
b_gs_ns_ks_strides_
{};
std
::
array
<
index_t
,
dimension
>
b1_gs_gemm1ns_gemm1ks_lengths_
{};
// b1_gs_os_ns_lengths
std
::
array
<
index_t
,
dimension
>
b1_gs_gemm1ns_gemm1ks_strides_
{};
// b1_gs_os_ns_strides
std
::
array
<
index_t
,
dimension
>
c_gs_ms_gemm1ns_lengths_
{};
// c_gs_ms_os_lengths
std
::
array
<
index_t
,
dimension
>
c_gs_ms_gemm1ns_strides_
{};
// c_gs_ms_os_strides
std
::
copy
(
a_gs_ms_ks_lengths
.
begin
(),
a_gs_ms_ks_lengths
.
begin
()
+
dimension
,
a_gs_ms_ks_lengths_
.
begin
());
std
::
copy
(
a_gs_ms_ks_strides
.
begin
(),
a_gs_ms_ks_strides
.
begin
()
+
dimension
,
a_gs_ms_ks_strides_
.
begin
());
std
::
copy
(
b_gs_ns_ks_lengths
.
begin
(),
b_gs_ns_ks_lengths
.
begin
()
+
dimension
,
b_gs_ns_ks_lengths_
.
begin
());
std
::
copy
(
b_gs_ns_ks_strides
.
begin
(),
b_gs_ns_ks_strides
.
begin
()
+
dimension
,
b_gs_ns_ks_strides_
.
begin
());
std
::
copy
(
b1_gs_gemm1ns_gemm1ks_lengths
.
begin
(),
b1_gs_gemm1ns_gemm1ks_lengths
.
begin
()
+
dimension
,
b1_gs_gemm1ns_gemm1ks_lengths_
.
begin
());
// b1_gs_os_ns_lengths
std
::
copy
(
b1_gs_gemm1ns_gemm1ks_strides
.
begin
(),
b1_gs_gemm1ns_gemm1ks_strides
.
begin
()
+
dimension
,
b1_gs_gemm1ns_gemm1ks_strides_
.
begin
());
// b1_gs_os_ns_strides
std
::
copy
(
c_gs_ms_gemm1ns_lengths
.
begin
(),
c_gs_ms_gemm1ns_lengths
.
begin
()
+
dimension
,
c_gs_ms_gemm1ns_lengths_
.
begin
());
// c_gs_ms_os_lengths
std
::
copy
(
c_gs_ms_gemm1ns_strides
.
begin
(),
c_gs_ms_gemm1ns_strides
.
begin
()
+
dimension
,
c_gs_ms_gemm1ns_strides_
.
begin
());
// c_gs_ms_os_strides
return
Argument
{
p_a
,
p_b
,
p_b1
,
p_c
,
p_acc0_biases
,
p_acc1_biases
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
,
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
a_gs_ms_ks_lengths
_
,
a_gs_ms_ks_strides
_
,
b_gs_ns_ks_lengths
_
,
b_gs_ns_ks_strides
_
,
b1_gs_gemm1ns_gemm1ks_lengths
_
,
// b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides
_
,
// b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths
_
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
_
,
// c_gs_ms_os_strides
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
,
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
...
...
@@ -891,20 +919,48 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1ElementwiseOperation
b1_element_op
,
C1DEElementwiseOperation
c1de_element_op
)
override
{
constexpr
auto
dimension
=
NumDimG
+
NumDimM
+
NumDimN
;
std
::
array
<
index_t
,
dimension
>
a_gs_ms_ks_lengths_
{};
std
::
array
<
index_t
,
dimension
>
a_gs_ms_ks_strides_
{};
std
::
array
<
index_t
,
dimension
>
b_gs_ns_ks_lengths_
{};
std
::
array
<
index_t
,
dimension
>
b_gs_ns_ks_strides_
{};
std
::
array
<
index_t
,
dimension
>
b1_gs_gemm1ns_gemm1ks_lengths_
{};
// b1_gs_os_ns_lengths
std
::
array
<
index_t
,
dimension
>
b1_gs_gemm1ns_gemm1ks_strides_
{};
// b1_gs_os_ns_strides
std
::
array
<
index_t
,
dimension
>
c_gs_ms_gemm1ns_lengths_
{};
// c_gs_ms_os_lengths
std
::
array
<
index_t
,
dimension
>
c_gs_ms_gemm1ns_strides_
{};
// c_gs_ms_os_strides
std
::
copy
(
a_gs_ms_ks_lengths
.
begin
(),
a_gs_ms_ks_lengths
.
begin
()
+
dimension
,
a_gs_ms_ks_lengths_
.
begin
());
std
::
copy
(
a_gs_ms_ks_strides
.
begin
(),
a_gs_ms_ks_strides
.
begin
()
+
dimension
,
a_gs_ms_ks_strides_
.
begin
());
std
::
copy
(
b_gs_ns_ks_lengths
.
begin
(),
b_gs_ns_ks_lengths
.
begin
()
+
dimension
,
b_gs_ns_ks_lengths_
.
begin
());
std
::
copy
(
b_gs_ns_ks_strides
.
begin
(),
b_gs_ns_ks_strides
.
begin
()
+
dimension
,
b_gs_ns_ks_strides_
.
begin
());
std
::
copy
(
b1_gs_gemm1ns_gemm1ks_lengths
.
begin
(),
b1_gs_gemm1ns_gemm1ks_lengths
.
begin
()
+
dimension
,
b1_gs_gemm1ns_gemm1ks_lengths_
.
begin
());
// b1_gs_os_ns_lengths
std
::
copy
(
b1_gs_gemm1ns_gemm1ks_strides
.
begin
(),
b1_gs_gemm1ns_gemm1ks_strides
.
begin
()
+
dimension
,
b1_gs_gemm1ns_gemm1ks_strides_
.
begin
());
// b1_gs_os_ns_strides
std
::
copy
(
c_gs_ms_gemm1ns_lengths
.
begin
(),
c_gs_ms_gemm1ns_lengths
.
begin
()
+
dimension
,
c_gs_ms_gemm1ns_lengths_
.
begin
());
// c_gs_ms_os_lengths
std
::
copy
(
c_gs_ms_gemm1ns_strides
.
begin
(),
c_gs_ms_gemm1ns_strides
.
begin
()
+
dimension
,
c_gs_ms_gemm1ns_strides_
.
begin
());
// c_gs_ms_os_strides
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
B1DataType
*>
(
p_b1
),
static_cast
<
CDataType
*>
(
p_c
),
p_acc0_biases
,
// cast in struct Argument
p_acc1_biases
,
// cast in struct Argument
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
,
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
a_gs_ms_ks_lengths
_
,
a_gs_ms_ks_strides
_
,
b_gs_ns_ks_lengths
_
,
b_gs_ns_ks_strides
_
,
b1_gs_gemm1ns_gemm1ks_lengths
_
,
// b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides
_
,
// b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths
_
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
_
,
// c_gs_ms_os_strides
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
,
acc1_biases_gs_ms_gemm1ns_lengths
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
View file @
b010b095
...
...
@@ -119,10 +119,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
AEnableLds
,
B0EnableLds
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
LoopSched
,
AEnableLds
,
B0EnableLds
>
())
>
;
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor
()
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
View file @
b010b095
...
...
@@ -15,10 +15,10 @@ enum struct PipelineVersion
};
template
<
PipelineVersion
PipelineVer
,
bool
AEnableLds
=
true
,
bool
BEnableLds
=
true
,
index_t
NumPrefetch
=
1
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
,
bool
AEnableLds
=
true
,
bool
BEnableLds
=
true
>
constexpr
auto
GridwiseGemmPipeline_Selector
()
{
if
constexpr
(
PipelineVer
==
PipelineVersion
::
v1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment