Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b41e6019
Commit
b41e6019
authored
Sep 06, 2022
by
Po-Yen, Chen
Browse files
Merge branch 'develop' into feature/add-permute-device-op
parents
d356c871
868e5c55
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
861 additions
and
735 deletions
+861
-735
client_example/06_softmax/CMakeLists.txt
client_example/06_softmax/CMakeLists.txt
+2
-0
client_example/06_softmax/softmax4d.cpp
client_example/06_softmax/softmax4d.cpp
+150
-0
client_example/CMakeLists.txt
client_example/CMakeLists.txt
+1
-0
example/23_softmax/softmax_blockwise.cpp
example/23_softmax/softmax_blockwise.cpp
+23
-17
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
+5
-0
example/32_batched_gemm_scale_softmax_gemm/padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
..._gemm/padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
+1
-6
include/ck/ck.hpp
include/ck/ck.hpp
+11
-0
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp
+32
-13
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
...tion/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
+3
-6
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+73
-200
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
.../device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
+71
-240
include/ck/tensor_operation/gpu/device/device_softmax.hpp
include/ck/tensor_operation/gpu/device/device_softmax.hpp
+47
-229
include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp
.../tensor_operation/gpu/device/impl/device_softmax_impl.hpp
+272
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
...n/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
+3
-10
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+61
-11
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp
...brary/tensor_operation_instance/gpu/batched_gemm_gemm.hpp
+20
-0
library/include/ck/library/tensor_operation_instance/gpu/softmax.hpp
...lude/ck/library/tensor_operation_instance/gpu/softmax.hpp
+71
-0
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt
...r_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+13
-3
No files found.
client_example/06_softmax/CMakeLists.txt
0 → 100644
View file @
b41e6019
add_executable
(
client_softmax4d softmax4d.cpp
)
target_link_libraries
(
client_softmax4d PRIVATE composable_kernel::device_operations
)
client_example/06_softmax/softmax4d.cpp
0 → 100644
View file @
b41e6019
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <functional>
#include <numeric>
#include <iomanip>
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/softmax.hpp"
using
InDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
constexpr
int
Rank
=
4
;
constexpr
int
NumReduceDim
=
2
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
(
int
argc
,
char
*
argv
[])
{
std
::
vector
<
ck
::
index_t
>
in_lengths
{
2
,
8
,
128
,
1024
};
std
::
vector
<
ck
::
index_t
>
in_strides
{
8
*
128
*
1024
,
128
*
1024
,
1024
,
1
};
std
::
vector
<
ck
::
index_t
>
reduce_dims
{
2
,
3
};
ck
::
index_t
num_elements
=
std
::
accumulate
(
in_lengths
.
begin
(),
in_lengths
.
end
(),
1
,
std
::
multiplies
<
ck
::
index_t
>
());
AccDataType
alpha
{
2.0
f
};
AccDataType
beta
{
2.0
f
};
SimpleDeviceMem
in
(
sizeof
(
InDataType
)
*
num_elements
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
num_elements
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceSoftmax
<
InDataType
,
AccDataType
,
OutDataType
,
PassThrough
,
PassThrough
,
Rank
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
bool
found
=
false
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
if
(
op_ptr
->
GetRank
()
!=
Rank
||
op_ptr
->
GetNumReduceDim
()
!=
NumReduceDim
)
{
continue
;
}
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in_lengths
,
in_strides
,
reduce_dims
,
&
alpha
,
&
beta
,
in
.
GetDeviceBuffer
(),
out
.
GetDeviceBuffer
(),
PassThrough
{},
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
num_bytes
=
num_elements
*
sizeof
(
InDataType
)
+
(
beta
==
0.0
f
?
1
:
2
)
*
num_elements
*
sizeof
(
OutDataType
);
float
gb_per_sec
=
num_bytes
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
ave_time
<
best_ave_time
)
{
found
=
true
;
best_op_id
=
i
;
best_op_name
=
op_name
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
// run the best intance
{
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in_lengths
,
in_strides
,
reduce_dims
,
&
alpha
,
&
beta
,
in
.
GetDeviceBuffer
(),
out
.
GetDeviceBuffer
(),
PassThrough
{},
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
return
0
;
}
\ No newline at end of file
client_example/CMakeLists.txt
View file @
b41e6019
...
@@ -11,3 +11,4 @@ add_subdirectory(02_gemm_add_add_fastgelu)
...
@@ -11,3 +11,4 @@ add_subdirectory(02_gemm_add_add_fastgelu)
add_subdirectory
(
03_gemm_layernorm
)
add_subdirectory
(
03_gemm_layernorm
)
add_subdirectory
(
04_contraction
)
add_subdirectory
(
04_contraction
)
add_subdirectory
(
05_layernorm
)
add_subdirectory
(
05_layernorm
)
add_subdirectory
(
06_softmax
)
example/23_softmax/softmax_blockwise.cpp
View file @
b41e6019
...
@@ -9,37 +9,41 @@
...
@@ -9,37 +9,41 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/
impl/
device_softmax
_impl
.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
using
namespace
ck
;
using
namespace
ck
::
tensor_operation
::
device
;
using
namespace
ck
::
tensor_operation
::
device
;
using
InDataType
=
ck
::
half_t
;
using
InDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
constexpr
int
Rank
=
3
;
constexpr
int
Rank
=
3
;
constexpr
int
NumReduceDim
=
1
;
constexpr
int
NumReduceDim
=
1
;
using
DeviceInstance
=
DeviceSoftmax
<
InDataType
,
using
DeviceInstance
=
DeviceSoftmaxImpl
<
InDataType
,
AccDataType
,
AccDataType
,
OutDataType
,
OutDataType
,
Rank
,
PassThrough
,
// InElementwiseOperation
NumReduceDim
,
PassThrough
,
// AccElementwiseOperation
256
,
// BlockSize
Rank
,
8
,
// ClusterM
NumReduceDim
,
32
,
// ClusterK
256
,
// BlockSize
1
,
// SliceM
8
,
// ClusterM
8
,
// SliceK
32
,
// ClusterK
1
,
// SrcVecDim (0=M, 1=K)
1
,
// SliceM
8
,
// SrcScalarPerVector
8
,
// SliceK
8
>
;
// OutScalarPerVector
1
,
// SrcVecDim (0=M, 1=K)
8
,
// SrcScalarPerVector
8
>
;
// OutScalarPerVector
static
struct
option
long_options
[]
=
{{
"inLengths"
,
required_argument
,
nullptr
,
'D'
},
static
struct
option
long_options
[]
=
{{
"inLengths"
,
required_argument
,
nullptr
,
'D'
},
{
"verify"
,
required_argument
,
nullptr
,
'v'
},
{
"verify"
,
required_argument
,
nullptr
,
'v'
},
...
@@ -196,7 +200,7 @@ int main(int argc, char* argv[])
...
@@ -196,7 +200,7 @@ int main(int argc, char* argv[])
if
(
args
.
do_verification
)
if
(
args
.
do_verification
)
{
{
using
ReferenceInstance
=
using
ReferenceInstance
=
tensor_operation
::
host
::
ReferenceSoftmax
<
InDataType
,
OutDataType
,
AccDataType
>
;
ck
::
tensor_operation
::
host
::
ReferenceSoftmax
<
InDataType
,
OutDataType
,
AccDataType
>
;
ReferenceInstance
ref
;
ReferenceInstance
ref
;
auto
ref_arg
=
ref
.
MakeArgument
(
in
,
out_ref
,
alpha
,
beta
,
reduceDims
);
auto
ref_arg
=
ref
.
MakeArgument
(
in
,
out_ref
,
alpha
,
beta
,
reduceDims
);
auto
invoker
=
ref
.
MakeInvoker
();
auto
invoker
=
ref
.
MakeInvoker
();
...
@@ -220,7 +224,9 @@ int main(int argc, char* argv[])
...
@@ -220,7 +224,9 @@ int main(int argc, char* argv[])
&
alpha
,
&
alpha
,
&
beta
,
&
beta
,
in_dev
.
GetDeviceBuffer
(),
in_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
());
out_dev
.
GetDeviceBuffer
(),
PassThrough
{},
PassThrough
{});
if
(
!
device_instance
.
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
!
device_instance
.
IsSupportedArgument
(
argument_ptr
.
get
()))
{
{
...
...
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
View file @
b41e6019
add_example_executable
(
example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
)
add_example_executable
(
example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
)
add_example_executable
(
example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16 padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16 padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
)
add_custom_target
(
example_batched_gemm_scale_softmax_gemm
)
add_dependencies
(
example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16
)
add_dependencies
(
example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16
)
add_dependencies
(
example_batched_gemm_scale_softmax_gemm example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16
)
example/32_batched_gemm_scale_softmax_gemm/padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
View file @
b41e6019
...
@@ -49,14 +49,9 @@ using B0Layout = Col;
...
@@ -49,14 +49,9 @@ using B0Layout = Col;
using
B1Layout
=
Row
;
using
B1Layout
=
Row
;
using
CLayout
=
Row
;
using
CLayout
=
Row
;
// When using padded DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle kernel, 2 specs should be set:
// 1. GemmSpecialization should be set to MNPadding(or NPadding in future)
// 2. Acc0ElementOp should be set to ScaleAndResetNaNToMinusInfinity
// Otherwise, wrong result may be produced.
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Scale
AndResetNaNToMinusInfinity
;
using
Acc0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
B1ElementOp
=
PassThrough
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
...
...
include/ck/ck.hpp
View file @
b41e6019
...
@@ -144,6 +144,17 @@
...
@@ -144,6 +144,17 @@
// workaround: compiler gnerating inefficient ds_write instructions
// workaround: compiler gnerating inefficient ds_write instructions
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
// (gfx908 only) workaround: compiler crash in fused kernels on mainline #9110; #10738 seems ok
// error message was "fatal error: error in backend: Error while trying to spill VGPR0 from class
// VGPR_32: Cannot scavenge register without an emergency spill slot!"
// this fall back to less ideal way of handle NPadding in fused attention kernel
#ifdef __gfx908__
#define CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER 1
#else
// for __gfx90a__, ...
#define CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER 0
#endif // __gfx908__
// workaround: verifaction failure, due to compiler regression, for conv bwd-data fp16 using some
// workaround: verifaction failure, due to compiler regression, for conv bwd-data fp16 using some
// tuning parameter
// tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0
#define CK_WORKAROUND_SWDEV_325164 0
...
...
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp
View file @
b41e6019
...
@@ -16,7 +16,8 @@ template <index_t BlockSize,
...
@@ -16,7 +16,8 @@ template <index_t BlockSize,
typename
AccDataType
,
typename
AccDataType
,
typename
ThreadMap_M_K
,
// thread_id to m_k
typename
ThreadMap_M_K
,
// thread_id to m_k
typename
ThreadClusterDesc_M_K
,
typename
ThreadClusterDesc_M_K
,
typename
ThreadSliceDesc_M_K
>
typename
ThreadSliceDesc_M_K
,
bool
IgnoreNaN
=
false
>
struct
BlockwiseSoftmax
struct
BlockwiseSoftmax
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -27,11 +28,33 @@ struct BlockwiseSoftmax
...
@@ -27,11 +28,33 @@ struct BlockwiseSoftmax
using
ThreadSliceDesc_M
=
decltype
(
using
ThreadSliceDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
ThreadSliceDesc_M_K
{}.
GetLength
(
I0
))));
make_naive_tensor_descriptor_packed
(
make_tuple
(
ThreadSliceDesc_M_K
{}.
GetLength
(
I0
))));
using
ThreadwiseMaxReduce
=
ThreadwiseReduction
<
AccDataType
,
using
ThreadwiseMaxReduce
=
typename
conditional
<
ThreadSliceDesc_M_K
,
IgnoreNaN
,
ThreadSliceDesc_M
,
ThreadwiseReduction
<
AccDataType
,
reduce
::
Max
,
ThreadSliceDesc_M_K
,
false
>
;
ThreadSliceDesc_M
,
reduce
::
Max
,
false
,
detail
::
AccumulateWithNanIgnore
<
reduce
::
Max
,
AccDataType
>>
,
ThreadwiseReduction
<
AccDataType
,
ThreadSliceDesc_M_K
,
ThreadSliceDesc_M
,
reduce
::
Max
,
false
>>::
type
;
using
ThreadwiseSumReduce
=
typename
conditional
<
IgnoreNaN
,
ThreadwiseReduction
<
AccDataType
,
ThreadSliceDesc_M_K
,
ThreadSliceDesc_M
,
reduce
::
Add
,
false
,
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
,
ThreadwiseReduction
<
AccDataType
,
ThreadSliceDesc_M_K
,
ThreadSliceDesc_M
,
reduce
::
Add
,
false
>>::
type
;
using
ThreadClusterLengths_M_K
=
decltype
(
ThreadClusterDesc_M_K
{}.
GetLengths
());
using
ThreadClusterLengths_M_K
=
decltype
(
ThreadClusterDesc_M_K
{}.
GetLengths
());
...
@@ -49,12 +72,6 @@ struct BlockwiseSoftmax
...
@@ -49,12 +72,6 @@ struct BlockwiseSoftmax
reduce
::
Add
,
reduce
::
Add
,
false
>
;
false
>
;
using
ThreadwiseSumReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadSliceDesc_M_K
,
ThreadSliceDesc_M
,
reduce
::
Add
,
false
>
;
using
BufferType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MRepeat
,
true
>
;
using
BufferType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MRepeat
,
true
>
;
template
<
typename
CThreadBuffer
,
typename
WorkspaceBuffer
>
template
<
typename
CThreadBuffer
,
typename
WorkspaceBuffer
>
...
@@ -74,7 +91,9 @@ struct BlockwiseSoftmax
...
@@ -74,7 +91,9 @@ struct BlockwiseSoftmax
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
math
::
exp
(
in_thread_buf
[
offset
]
-
max_value_buf
(
iM
));
in_thread_buf
(
offset
)
=
IgnoreNaN
&&
ck
::
math
::
isnan
(
in_thread_buf
[
offset
])
?
0
:
math
::
exp
(
in_thread_buf
[
offset
]
-
max_value_buf
(
iM
));
});
});
});
});
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
View file @
b41e6019
...
@@ -456,8 +456,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
...
@@ -456,8 +456,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
b_grid_desc_bk0_n_bk1_
,
b_grid_desc_bk0_n_bk1_
,
b1_grid_desc_bk0_n_bk1_
,
b1_grid_desc_bk0_n_bk1_
,
c_grid_desc_m_n_
,
c_grid_desc_m_n_
,
block_2_ctile_map_
,
block_2_ctile_map_
))
raw_lengths_m_n_k_o_
))
{
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
...
@@ -508,8 +507,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
...
@@ -508,8 +507,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
,
arg
.
block_2_ctile_map_
))
arg
.
raw_lengths_m_n_k_o_
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
}
...
@@ -628,8 +626,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
...
@@ -628,8 +626,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
,
arg
.
block_2_ctile_map_
);
arg
.
raw_lengths_m_n_k_o_
);
}
}
// polymorphic
// polymorphic
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
b41e6019
...
@@ -194,6 +194,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -194,6 +194,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
GemmGemmPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
,
index_t
>
{
GemmGemmPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
};
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
};
// FIXME: pad K
static_assert
(
!
matrix_padder
.
PadK
,
"KPadding is currently not supported"
);
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
...
@@ -209,92 +212,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -209,92 +212,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
}
}();
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
a_grid_desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
const
auto
AK0
=
K
/
AK1
;
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
// not pad M or K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
}
}
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
...
@@ -312,84 +241,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -312,84 +241,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
}
}();
}();
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
b_grid_desc_n_k
=
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
}
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
NRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
const
auto
BK0
=
K
/
BK1
;
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
}
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
...
@@ -408,47 +271,19 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -408,47 +271,19 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
}
}();
}();
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
Gemm1NPerBlock
)
*
Gemm1NPerBlock
;
const
auto
b1_grid_desc_n_k
=
matrix_padder
.
PadB1Descriptor_N_K
(
b1_grid_desc_nraw_kraw
);
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
Gemm1KPerBlock
)
*
Gemm1KPerBlock
;
const
auto
N
Pad
=
N
-
NRaw
;
const
auto
N
=
b1_grid_desc_n_k
.
GetLength
(
I0
)
;
const
auto
K
Pad
=
K
-
KRaw
;
const
auto
K
=
b1_grid_desc_n_k
.
GetLength
(
I1
)
;
// TODO: implement finer-grained padding
const
auto
B1K0
=
K
/
B1K1
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
Default
)
{
const
auto
B1K0
=
KRaw
/
B1K1
;
const
auto
b1_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
b1_grid_desc_nraw_kraw
,
b1_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
NRaw
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b1_grid_desc_bk0_n_bk1
;
}
else
{
// pad both B1N and B1K
const
auto
B1K0
=
K
/
B1K1
;
const
auto
b1_grid_desc_n_k
=
transform_tensor_descriptor
(
b1_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b1_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b1_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b1_grid_desc_bk0_n_bk1
;
}
}
}
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
...
@@ -662,7 +497,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -662,7 +497,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleNXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
LoopSched
,
matrix_padder
.
PadN
>
;
// Argument
// Argument
// FIXME: constness
// FIXME: constness
...
@@ -711,7 +547,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -711,7 +547,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
c_element_op_
{
c_element_op
},
c_element_op_
{
c_element_op
},
batch_count_
(
Batch
),
batch_count_
(
Batch
),
compute_base_ptr_of_batch_
{
compute_base_ptr_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
c_grid_desc_g_m_n_
}
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
c_grid_desc_g_m_n_
},
raw_lengths_m_n_k_o_
{
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
},
c_extent_lowest_
{
c_gs_ms_gemm1ns_lengths
.
back
()},
c_stride_lowest_
{
c_gs_ms_gemm1ns_strides
.
back
()}
{
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
b_grid_desc_bk0_n_bk1_
,
...
@@ -745,6 +584,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -745,6 +584,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
batch_count_
;
index_t
batch_count_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
// For robust IsSupportedArgument() check
std
::
vector
<
index_t
>
raw_lengths_m_n_k_o_
;
index_t
c_extent_lowest_
;
index_t
c_stride_lowest_
;
};
};
// Invoker
// Invoker
...
@@ -859,7 +703,35 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -859,7 +703,35 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return
false
;
return
false
;
}
}
// TODO: Check A/B0/B1 length & stride and scalar per vector
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
const
auto
MRaw
=
arg
.
raw_lengths_m_n_k_o_
[
0
];
const
auto
NRaw
=
arg
.
raw_lengths_m_n_k_o_
[
1
];
const
auto
KRaw
=
arg
.
raw_lengths_m_n_k_o_
[
2
];
const
auto
Gemm1NRaw
=
arg
.
raw_lengths_m_n_k_o_
[
3
];
// Check scalar per vector requirement
const
auto
a_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
?
KRaw
:
MRaw
;
const
auto
b_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
?
NRaw
:
KRaw
;
const
auto
b1_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
B1Layout
>
?
Gemm1NRaw
:
NRaw
;
const
auto
c_extent_lowest
=
arg
.
c_extent_lowest_
;
if
(
!
(
a_extent_lowest
%
ABlockTransferSrcScalarPerVector
==
0
&&
b_extent_lowest
%
BBlockTransferSrcScalarPerVector
==
0
&&
b1_extent_lowest
%
B1BlockTransferSrcScalarPerVector
==
0
&&
c_extent_lowest
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
// Check vector store requirement; assumes last dimension in N to be contiguous
if
(
arg
.
c_stride_lowest_
!=
1
)
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
...
@@ -996,7 +868,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -996,7 +868,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
<<
Gemm1NPerBlock
<<
", "
<<
Gemm1NPerBlock
<<
", "
<<
Gemm1KPerBlock
<<
", "
<<
Gemm1KPerBlock
<<
", "
<<
B1K1
<<
">"
;
<<
B1K1
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
">"
;
// clang-format on
// clang-format on
return
str
.
str
();
return
str
.
str
();
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
View file @
b41e6019
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -198,6 +199,13 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -198,6 +199,13 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
matrix_padder
=
GemmGemmPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
};
// FIXME: pad K
static_assert
(
!
matrix_padder
.
PadK
,
"KPadding is currently not supported"
);
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
...
@@ -213,92 +221,18 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -213,92 +221,18 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
}
}();
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
a_grid_desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
a_grid_desc_mraw_kraw
,
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
const
auto
AK0
=
K
/
AK1
;
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
}
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
else
make_pass_through_transform
(
M
)),
{
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
// not pad M or K
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
}
}
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
...
@@ -316,84 +250,18 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -316,84 +250,18 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
}
}();
}();
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
b_grid_desc_n_k
=
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
}
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
NRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
const
auto
BK0
=
K
/
BK1
;
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
}
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
...
@@ -412,47 +280,19 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -412,47 +280,19 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
}
}();
}();
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
Gemm1NPerBlock
)
*
Gemm1NPerBlock
;
const
auto
b1_grid_desc_n_k
=
matrix_padder
.
PadB1Descriptor_N_K
(
b1_grid_desc_nraw_kraw
);
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
Gemm1KPerBlock
)
*
Gemm1KPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
// TODO: implement finer-grained padding
const
auto
N
=
b1_grid_desc_n_k
.
GetLength
(
I0
);
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
Default
)
const
auto
K
=
b1_grid_desc_n_k
.
GetLength
(
I1
);
{
const
auto
B1K0
=
KRaw
/
B1K1
;
const
auto
b1_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
const
auto
B1K0
=
K
/
B1K1
;
b1_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b1_grid_desc_bk0_n_bk1
;
return
transform_tensor_descriptor
(
}
b1_grid_desc_n_k
,
else
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
{
make_pass_through_transform
(
N
)),
// pad both B1N and B1K
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
const
auto
B1K0
=
K
/
B1K1
;
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
b1_grid_desc_n_k
=
transform_tensor_descriptor
(
b1_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b1_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b1_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b1_grid_desc_bk0_n_bk1
;
}
}
}
static
auto
MakeCGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideC
)
static
auto
MakeCGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideC
)
...
@@ -470,47 +310,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -470,47 +310,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
}
}();
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
return
matrix_padder
.
PadCDescriptor_M_N
(
c_grid_desc_mraw_nraw
);
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
Gemm1NPerBlock
)
*
Gemm1NPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad M, but not N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or N
return
c_grid_desc_mraw_nraw
;
}
}
}
struct
ComputeBasePtrOfStridedBatch
struct
ComputeBasePtrOfStridedBatch
...
@@ -617,7 +417,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -617,7 +417,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CShuffleNXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
LoopSched
,
matrix_padder
.
PadN
>
;
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
...
@@ -661,7 +462,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -661,7 +462,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
b1_element_op_
{
b1_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c_element_op_
{
c_element_op
},
batch_count_
(
Batch
),
batch_count_
(
Batch
),
compute_base_ptr_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
}
compute_base_ptr_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
},
raw_lengths_m_n_k_o_
{
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
}
{
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
b_grid_desc_bk0_n_bk1_
,
...
@@ -694,6 +496,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -694,6 +496,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
batch_count_
;
index_t
batch_count_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
// For robust IsSupportedArgument() check
std
::
vector
<
index_t
>
raw_lengths_m_n_k_o_
;
};
};
// Invoker
// Invoker
...
@@ -797,6 +602,31 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -797,6 +602,31 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return
false
;
return
false
;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
const
auto
MRaw
=
arg
.
raw_lengths_m_n_k_o_
[
0
];
const
auto
NRaw
=
arg
.
raw_lengths_m_n_k_o_
[
1
];
const
auto
KRaw
=
arg
.
raw_lengths_m_n_k_o_
[
2
];
const
auto
Gemm1NRaw
=
arg
.
raw_lengths_m_n_k_o_
[
3
];
// Check scalar per vector requirement
const
auto
a_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
?
KRaw
:
MRaw
;
const
auto
b_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
?
NRaw
:
KRaw
;
const
auto
b1_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
B1Layout
>
?
Gemm1NRaw
:
NRaw
;
const
auto
c_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>
?
Gemm1NRaw
:
MRaw
;
if
(
!
(
a_extent_lowest
%
ABlockTransferSrcScalarPerVector
==
0
&&
b_extent_lowest
%
BBlockTransferSrcScalarPerVector
==
0
&&
b1_extent_lowest
%
B1BlockTransferSrcScalarPerVector
==
0
&&
c_extent_lowest
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
...
@@ -913,7 +743,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -913,7 +743,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
<<
Gemm1NPerBlock
<<
", "
<<
Gemm1NPerBlock
<<
", "
<<
Gemm1KPerBlock
<<
", "
<<
Gemm1KPerBlock
<<
", "
<<
B1K1
<<
">"
;
<<
B1K1
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
">"
;
// clang-format on
// clang-format on
return
str
.
str
();
return
str
.
str
();
...
...
include/ck/tensor_operation/gpu/device/device_softmax.hpp
View file @
b41e6019
...
@@ -3,19 +3,10 @@
...
@@ -3,19 +3,10 @@
#pragma once
#pragma once
#include <
iostream
>
#include <
memory
>
#include <
sstream
>
#include <
vector
>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -24,227 +15,54 @@ namespace device {
...
@@ -24,227 +15,54 @@ namespace device {
template
<
typename
InDataType
,
template
<
typename
InDataType
,
typename
AccDataType
,
typename
AccDataType
,
typename
OutDataType
,
typename
OutDataType
,
index_t
Rank
,
typename
InElementwiseOp
,
index_t
NumReduceDim
,
typename
AccElementwiseOp
,
index_t
BlockSize
,
index_t
Rank
>
index_t
MThreadClusterSize
,
struct
DeviceSoftmax
:
public
BaseOperator
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
OutDstVectorSize
>
struct
DeviceSoftmax
:
public
DeviceNormalization
{
{
static
constexpr
index_t
kRank
=
Rank
;
//
static
constexpr
index_t
kNumReduceDim
=
NumReduceDim
;
// @brief Makes a pointer to Argument class.
//
virtual
index_t
GetRank
()
const
override
{
return
kRank
;
}
// @param[in] inLengths Input tensor extent(s) from high to low dimension
// @param[in] inStrides Input tensor stride(s) from high to low dimension
virtual
index_t
GetNumReduceDim
()
const
override
{
return
kNumReduceDim
;
}
// @param[in] reduceDims The dimension(s) the normalization operation is applied
// @param[in] alpha Typeless pointer in host memory storing the alpha scaling
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
// value as type AccDataType
// @param[in] beta Typeless pointer in host memory storing the beta scaling
// Used for freeloading of some handy functions from DeviceReduceMultiBlock
// value as type AccDataType
using
Reduction
=
DeviceReduceMultiBlock
<
InDataType
,
// @param[in] in_dev Typeless const pointer in device memory storing the input
AccDataType
,
// tensor
OutDataType
,
// @param out_dev Typeless pointer in device memory storing the output tensor
Rank
,
// @param[in] in_elementwise_op The input elementwise operation.
NumReduceDim
,
// @param[in] acc_elementwise_op The accumulation elementwise operation.
reduce
::
Add
,
//
PassThrough
,
// InElementwiseOperation
// @return Unique pointer to the Argument class.
PassThrough
,
// AccElementwiseOperation
//
InMemoryDataOperationEnum
::
Set
,
virtual
std
::
unique_ptr
<
BaseArgument
>
false
,
// PropagateNan
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
inLengths
,
false
,
// OutputIndex
const
std
::
vector
<
index_t
>
inStrides
,
false
,
// HaveIndexInputIfOutputIndex
const
std
::
vector
<
int
>
reduceDims
,
BlockSize
,
const
void
*
alpha
,
MThreadClusterSize
,
const
void
*
beta
,
KThreadClusterSize
,
const
void
*
in_dev
,
MThreadSliceSize
,
void
*
out_dev
,
KThreadSliceSize
,
InElementwiseOp
in_elementwise_op
,
InSrcVectorDim
,
AccElementwiseOp
acc_elementwise_op
)
=
0
;
InSrcVectorSize
,
1
>
;
// OutDstVectorSize
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
index_t
GetRank
()
const
=
0
;
using
GridDesc_M_K
=
decltype
(
Reduction
::
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
virtual
index_t
GetNumReduceDim
()
const
=
0
;
using
GridwiseSoftmaxGeneric
=
GridwiseSoftmax_mk_to_mk
<
InDataType
,
OutDataType
,
AccDataType
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorSize
,
OutDstVectorSize
,
false
>
;
using
GridwiseSoftmaxSweepOnce
=
GridwiseSoftmax_mk_to_mk
<
InDataType
,
OutDataType
,
AccDataType
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorSize
,
OutDstVectorSize
,
true
>
;
struct
Argument
:
public
Reduction
::
Argument
{
Argument
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
AccDataType
alpha
,
AccDataType
beta
,
const
InDataType
*
in_dev
,
OutDataType
*
out_dev
)
:
Reduction
::
Argument
(
inLengths
,
inStrides
,
{},
{},
reduceDims
,
0.0
f
,
// alpha
0.0
f
,
// beta
in_dev
,
nullptr
,
out_dev
,
nullptr
,
PassThrough
{},
PassThrough
{}),
// FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
// float32 precision. Make it support any data type so the fields can be removed.
alpha_
(
alpha
),
beta_
(
beta
)
{
// std::cout << "blkGroupSize= " << this->blkGroupSize
// << ", numBlockTileIteration= " << this->numBlockTileIteration
// << ", gridSize=" << this->gridSize
// << ", invariant_total_length=" << this->invariant_total_length <<
// std::endl;
}
AccDataType
alpha_
;
AccDataType
beta_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
in_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
inStrides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
const
auto
out_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
inStrides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
bool
sweep_once
=
in_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{})
<=
KThreadClusterSize
*
KThreadSliceSize
;
const
auto
kernel_main
=
sweep_once
?
kernel_softmax
<
GridwiseSoftmaxSweepOnce
,
InDataType
,
OutDataType
,
AccDataType
,
GridDesc_M_K
>
:
kernel_softmax
<
GridwiseSoftmaxGeneric
,
InDataType
,
OutDataType
,
AccDataType
,
GridDesc_M_K
>
;
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_main
,
dim3
(
arg
.
gridSize
),
dim3
(
BlockSize
),
0
,
in_grid_desc_m_k
,
out_grid_desc_m_k
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
,
arg
.
alpha_
,
arg
.
in_dev_
,
arg
.
beta_
,
arg
.
out_dev_
);
return
(
avg_time
);
};
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
};
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
!
Reduction
::
IsSupportedArgument
(
p_arg_
))
{
return
false
;
}
if
(
p_arg_
->
inLengths_
[
Rank
-
1
]
%
OutDstVectorSize
!=
0
)
{
return
false
;
}
return
true
;
};
// inLengths: input tensor extent(s) from high to low dimension
// inStrides: input tensor stride(s) from high to low dimension
// reduceDims: the dimension(s) the softmax normalization operate on
// alpha: typeless pointer in host memory storing the alpha scaling value as type AccDataType
// beta: typeless pointer in host memory storing the beta scaling value as type AccDataType
// in_dev: typeless const pointer in device memory storing the input tensor
// out_dev: typeless pointer in device memory storing the output tensor
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
int
>
reduceDims
,
const
void
*
alpha
,
const
void
*
beta
,
const
void
*
in_dev
,
void
*
out_dev
)
override
{
return
std
::
make_unique
<
Argument
>
(
inLengths
,
inStrides
,
reduceDims
,
*
static_cast
<
const
AccDataType
*>
(
alpha
),
*
static_cast
<
const
AccDataType
*>
(
beta
),
static_cast
<
const
InDataType
*>
(
in_dev
),
static_cast
<
OutDataType
*>
(
out_dev
));
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
};
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceReduceSoftmax<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"InSrcVectorDim_"
<<
InSrcVectorDim
<<
"_InSrcVectorSize_"
<<
InSrcVectorSize
<<
"_OutDstVectorSize_"
<<
OutDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
};
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
,
typename
InElementwiseOp
,
typename
AccElementwiseOp
,
index_t
Rank
>
using
DeviceSoftmaxPtr
=
std
::
unique_ptr
<
DeviceSoftmax
<
InDataType
,
AccDataType
,
OutDataType
,
InElementwiseOp
,
AccElementwiseOp
,
Rank
>>
;
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp
0 → 100644
View file @
b41e6019
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
,
typename
InElementwiseOp
,
typename
AccElementwiseOp
,
index_t
Rank
,
index_t
NumReduceDim
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
OutDstVectorSize
>
struct
DeviceSoftmaxImpl
:
public
DeviceSoftmax
<
InDataType
,
AccDataType
,
OutDataType
,
InElementwiseOp
,
AccElementwiseOp
,
Rank
>
{
static
constexpr
index_t
kRank
=
Rank
;
static
constexpr
index_t
kNumReduceDim
=
NumReduceDim
;
virtual
index_t
GetRank
()
const
override
{
return
kRank
;
}
virtual
index_t
GetNumReduceDim
()
const
override
{
return
kNumReduceDim
;
}
// Used for freeloading of some handy functions from DeviceReduceMultiBlock
using
Reduction
=
DeviceReduceMultiBlock
<
InDataType
,
AccDataType
,
OutDataType
,
Rank
,
NumReduceDim
,
reduce
::
Add
,
InElementwiseOp
,
AccElementwiseOp
,
InMemoryDataOperationEnum
::
Set
,
false
,
// PropagateNan
false
,
// OutputIndex
false
,
// HaveIndexInputIfOutputIndex
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorSize
,
1
>
;
// OutDstVectorSize
using
GridDesc_M_K
=
decltype
(
Reduction
::
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
GridwiseSoftmaxGeneric
=
GridwiseSoftmax_mk_to_mk
<
InDataType
,
OutDataType
,
AccDataType
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorSize
,
OutDstVectorSize
,
false
>
;
using
GridwiseSoftmaxSweepOnce
=
GridwiseSoftmax_mk_to_mk
<
InDataType
,
OutDataType
,
AccDataType
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorSize
,
OutDstVectorSize
,
true
>
;
struct
Argument
:
public
Reduction
::
Argument
{
Argument
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
AccDataType
alpha
,
AccDataType
beta
,
const
InDataType
*
in_dev
,
OutDataType
*
out_dev
,
InElementwiseOp
in_elementwise_op
,
AccElementwiseOp
acc_elementwise_op
)
:
Reduction
::
Argument
(
inLengths
,
inStrides
,
{},
{},
reduceDims
,
0.0
f
,
// alpha
0.0
f
,
// beta
in_dev
,
nullptr
,
out_dev
,
nullptr
,
in_elementwise_op
,
acc_elementwise_op
),
// FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
// float32 precision. Make it support any data type so the fields can be removed.
alpha_
(
alpha
),
beta_
(
beta
)
{
// std::cout << "blkGroupSize= " << this->blkGroupSize
// << ", numBlockTileIteration= " << this->numBlockTileIteration
// << ", gridSize=" << this->gridSize
// << ", invariant_total_length=" << this->invariant_total_length <<
// std::endl;
}
AccDataType
alpha_
;
AccDataType
beta_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
in_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
inStrides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
const
auto
out_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
inStrides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
bool
sweep_once
=
in_grid_desc_m_k
.
GetLength
(
Number
<
1
>
{})
<=
KThreadClusterSize
*
KThreadSliceSize
;
const
auto
kernel_main
=
sweep_once
?
kernel_softmax
<
GridwiseSoftmaxSweepOnce
,
InDataType
,
OutDataType
,
AccDataType
,
GridDesc_M_K
>
:
kernel_softmax
<
GridwiseSoftmaxGeneric
,
InDataType
,
OutDataType
,
AccDataType
,
GridDesc_M_K
>
;
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_main
,
dim3
(
arg
.
gridSize
),
dim3
(
BlockSize
),
0
,
in_grid_desc_m_k
,
out_grid_desc_m_k
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
,
arg
.
alpha_
,
arg
.
in_dev_
,
arg
.
beta_
,
arg
.
out_dev_
);
return
(
avg_time
);
};
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
};
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
!
Reduction
::
IsSupportedArgument
(
p_arg_
))
{
return
false
;
}
if
(
p_arg_
->
inLengths_
[
Rank
-
1
]
%
OutDstVectorSize
!=
0
)
{
return
false
;
}
return
true
;
};
//
// @brief Makes a pointer to Argument class.
//
// @param[in] inLengths Input tensor extent(s) from high to low dimension
// @param[in] inStrides Input tensor stride(s) from high to low dimension
// @param[in] reduceDims The dimension(s) the normalization operation is applied
// @param[in] alpha Typeless pointer in host memory storing the alpha scaling
// value as type AccDataType
// @param[in] beta Typeless pointer in host memory storing the beta scaling
// value as type AccDataType
// @param[in] in_dev Typeless const pointer in device memory storing the input
// tensor
// @param out_dev Typeless pointer in device memory storing the output tensor
// @param[in] in_elementwise_op The input elementwise operation.
// @param[in] acc_elementwise_op The accumulation elementwise operation.
//
// @return Unique pointer to the Argument class.
//
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
int
>
reduceDims
,
const
void
*
alpha
,
const
void
*
beta
,
const
void
*
in_dev
,
void
*
out_dev
,
InElementwiseOp
in_elementwise_op
,
AccElementwiseOp
acc_elementwise_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
inLengths
,
inStrides
,
reduceDims
,
*
static_cast
<
const
AccDataType
*>
(
alpha
),
*
static_cast
<
const
AccDataType
*>
(
beta
),
static_cast
<
const
InDataType
*>
(
in_dev
),
static_cast
<
OutDataType
*>
(
out_dev
),
in_elementwise_op
,
acc_elementwise_op
);
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
};
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceReduceSoftmax<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"InSrcVectorDim_"
<<
InSrcVectorDim
<<
"_InSrcVectorSize_"
<<
InSrcVectorSize
<<
"_OutDstVectorSize_"
<<
OutDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
View file @
b41e6019
...
@@ -200,8 +200,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
...
@@ -200,8 +200,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
std
::
vector
<
index_t
>&
lengths_m_n_k_o
)
{
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
...
@@ -217,13 +216,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
...
@@ -217,13 +216,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
return
false
;
return
false
;
}
}
// K is rounded to nearest multiples of K1 during tensor transformation so instead get KRaw
const
auto
KRaw
=
lengths_m_n_k_o
[
2
];
if
(
!
(
KRaw
%
AK1
==
0
&&
KRaw
%
BK1
==
0
))
{
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
Gemm1N
%
Gemm1NPerBlock
==
0
))
{
{
...
@@ -602,8 +594,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
...
@@ -602,8 +594,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
static_cast
<
FloatAB
*>
(
p_shared
)
+
SharedMemTrait
::
b1_block_space_offset
,
static_cast
<
FloatAB
*>
(
p_shared
)
+
SharedMemTrait
::
b1_block_space_offset
,
b1_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
b1_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
// selected_mfma.k_per_blk <= B1K1 <= selected_mfma.group_size
constexpr
index_t
Gemm1KPack
=
math
::
max
(
constexpr
index_t
Gemm1KPack
=
math
::
max
(
math
::
lcm
(
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
,
B1K1
),
math
::
gcd
(
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
,
B1K1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
gemm1_blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
auto
gemm1_blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
b41e6019
...
@@ -75,7 +75,8 @@ template <typename FloatAB,
...
@@ -75,7 +75,8 @@ template <typename FloatAB,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
>
LoopScheduler
LoopSched
,
bool
PadN
>
struct
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
struct
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
{
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
...
@@ -330,6 +331,36 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -330,6 +331,36 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
};
};
template
<
bool
Pred
>
struct
ElementOpPredicatedResetNaNToMinusInf
;
template
<
>
struct
ElementOpPredicatedResetNaNToMinusInf
<
true
>
{
template
<
typename
ElementOp
,
typename
OutT
,
typename
InT
>
__host__
__device__
void
Run
(
OutT
&
y
,
const
ElementOp
&
op
,
const
InT
&
x
)
{
if
(
ck
::
math
::
isnan
(
x
))
{
y
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
}
else
{
op
(
y
,
x
);
}
}
};
template
<
>
struct
ElementOpPredicatedResetNaNToMinusInf
<
false
>
{
template
<
typename
ElementOp
,
typename
OutT
,
typename
InT
>
__host__
__device__
void
Run
(
OutT
&
y
,
const
ElementOp
&
op
,
const
InT
&
x
)
{
op
(
y
,
x
);
}
};
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
>
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
...
@@ -348,14 +379,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -348,14 +379,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
p_a_grid
,
conditional_expr
<
PadN
>
(
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
p_a_grid
,
NumericLimits
<
FloatAB
>::
QuietNaN
());
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
NumericLimits
<
FloatAB
>::
QuietNaN
()),
p_b_grid
,
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
()));
NumericLimits
<
FloatAB
>::
QuietNaN
());
const
auto
b_grid_buf
=
conditional_expr
<
PadN
>
(
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
NumericLimits
<
FloatAB
>::
QuietNaN
()),
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
()));
const
auto
b1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
b1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b1_grid
,
b1_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
p_b1_grid
,
b1_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -608,8 +645,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -608,8 +645,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static_cast
<
FloatAB
*>
(
p_shared
)
+
SharedMemTrait
::
b1_block_space_offset
,
static_cast
<
FloatAB
*>
(
p_shared
)
+
SharedMemTrait
::
b1_block_space_offset
,
b1_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
b1_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
// selected_mfma.k_per_blk <= B1K1 <= selected_mfma.group_size
constexpr
index_t
Gemm1KPack
=
math
::
max
(
constexpr
index_t
Gemm1KPack
=
math
::
max
(
math
::
lcm
(
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
,
B1K1
),
math
::
gcd
(
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
,
B1K1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
gemm1_blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
auto
gemm1_blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
...
@@ -680,7 +718,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -680,7 +718,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatGemmAcc
,
FloatGemmAcc
,
decltype
(
threadid_to_m_n_thread_cluster_adaptor
),
decltype
(
threadid_to_m_n_thread_cluster_adaptor
),
decltype
(
thread_cluster_desc_m_n
),
decltype
(
thread_cluster_desc_m_n
),
decltype
(
thread_slice_desc_m_n
)
>
{};
decltype
(
thread_slice_desc_m_n
)
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
,
true
#endif
>
{};
const
index_t
num_gemm1_k_block_outer_loop
=
const
index_t
num_gemm1_k_block_outer_loop
=
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
)
/
NPerBlock
;
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
)
/
NPerBlock
;
...
@@ -721,8 +764,15 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -721,8 +764,15 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
num_k_block_main_loop
);
num_k_block_main_loop
);
// Acc0 elementwise Op
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
static_for
<
0
,
acc_thread_buf
.
Size
(),
1
>
{}(
static_for
<
0
,
acc_thread_buf
.
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
acc_element_op
(
acc_thread_buf
(
i
),
acc_thread_buf
[
i
]);
});
[
&
](
auto
i
)
{
acc_element_op
(
acc_thread_buf
(
i
),
acc_thread_buf
[
i
]);
});
#else
static_for
<
0
,
acc_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
ElementOpPredicatedResetNaNToMinusInf
<
PadN
>
{}.
Run
(
acc_thread_buf
(
i
),
acc_element_op
,
acc_thread_buf
[
i
]);
});
#endif
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp
View file @
b41e6019
...
@@ -32,6 +32,20 @@ void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_i
...
@@ -32,6 +32,20 @@ void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_i
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmGemm
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
B0Layout
,
typename
B0Layout
,
typename
B1Layout
,
typename
B1Layout
,
...
@@ -82,6 +96,12 @@ struct DeviceOperationInstanceFactory<
...
@@ -82,6 +96,12 @@ struct DeviceOperationInstanceFactory<
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
op_ptrs
);
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
B0Layout
,
Col
>
&&
is_same_v
<
B1Layout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance
(
op_ptrs
);
}
}
}
return
op_ptrs
;
return
op_ptrs
;
}
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/softmax.hpp
0 → 100644
View file @
b41e6019
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <type_traits>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
void
add_device_softmax_f16_f16_rank3_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
3
>>&
);
void
add_device_softmax_f16_f16_rank4_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
4
>>&
);
void
add_device_softmax_f32_f32_rank3_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
3
>>&
);
void
add_device_softmax_f32_f32_rank4_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
4
>>&
);
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
,
index_t
Rank
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceSoftmax
<
InDataType
,
AccDataType
,
OutDataType
,
PassThrough
,
PassThrough
,
Rank
>>
{
using
DeviceOp
=
DeviceSoftmax
<
InDataType
,
AccDataType
,
OutDataType
,
PassThrough
,
PassThrough
,
Rank
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
std
::
is_same_v
<
InDataType
,
F16
>
&&
std
::
is_same_v
<
AccDataType
,
F32
>
&&
std
::
is_same_v
<
OutDataType
,
F16
>
)
{
if
constexpr
(
Rank
==
3
)
add_device_softmax_f16_f16_rank3_instances
(
op_ptrs
);
else
if
constexpr
(
Rank
==
4
)
add_device_softmax_f16_f16_rank4_instances
(
op_ptrs
);
}
else
if
constexpr
(
std
::
is_same_v
<
InDataType
,
F32
>
&&
std
::
is_same_v
<
AccDataType
,
F32
>
&&
std
::
is_same_v
<
OutDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
3
)
add_device_softmax_f32_f32_rank3_instances
(
op_ptrs
);
else
if
constexpr
(
Rank
==
4
)
add_device_softmax_f32_f32_rank4_instances
(
op_ptrs
);
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
b41e6019
...
@@ -78,6 +78,7 @@ target_include_directories(device_operations PUBLIC
...
@@ -78,6 +78,7 @@ target_include_directories(device_operations PUBLIC
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/problem_transform>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/problem_transform>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/device>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/device>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/device/impl>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/grid>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/grid>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/block>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/block>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/warp>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_operation/gpu/warp>
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt
View file @
b41e6019
add_instance_library
(
device_batched_gemm_gemm_instance
add_instance_library
(
device_batched_gemm_gemm_instance
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
)
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
b41e6019
...
@@ -35,11 +35,21 @@ using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_inst
...
@@ -35,11 +35,21 @@ using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_inst
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
32
,
128
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
32
,
64
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
64
,
128
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
64
,
64
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
// Padded fallback kernel
// Padded fallback kernel
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmPadded
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmPadded
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmPadded
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
// clang-format on
>
;
>
;
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment