gaoqiong / composable_kernel / Commits / 5aa3c344

Unverified commit 5aa3c344, authored Oct 05, 2022 by rocking5566, committed by GitHub on Oct 05, 2022

Merge branch 'develop' into gemm_layernorm_welford

Parents: 7fefc966 9d8f834a
Changes: 129

Showing 20 changed files with 1034 additions and 171 deletions (+1034 / -171)
include/ck/utility/span.hpp  (+67 / -0)
include/ck/utility/transpose_vectors.hpp  (+17 / -21)
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp  (+191 / -0)
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp  (+139 / -0)
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp  (+100 / -0)
library/include/ck/library/tensor_operation_instance/gpu/layernorm.hpp  (+36 / -12)
library/include/ck/library/utility/check_err.hpp  (+25 / -13)
library/include/ck/library/utility/fill.hpp  (+12 / -0)
library/include/ck/library/utility/host_tensor.hpp  (+45 / -13)
library/src/tensor_operation_instance/gpu/CMakeLists.txt  (+12 / -54)
library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt  (+4 / -0)
library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp  (+80 / -0)
library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp  (+81 / -0)
library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt  (+4 / -0)
library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp  (+85 / -0)
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp  (+38 / -20)
library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp  (+45 / -2)
library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f16_instance.cpp  (+25 / -19)
library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f32_instance.cpp  (+23 / -17)
profiler/CMakeLists.txt  (+5 / -0)
include/ck/utility/span.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstddef>
#include <array>
#include <type_traits>

namespace ck {

template <typename T>
class span
{
    public:
    using element_type    = T;
    using value_type      = std::remove_cv_t<element_type>;
    using size_type       = std::size_t;
    using difference_type = std::ptrdiff_t;
    using pointer         = element_type*;
    using const_pointer   = const element_type*;
    using reference       = element_type&;
    using const_reference = const element_type&;
    using iterator        = pointer;
    using const_iterator  = pointer;

    constexpr span() : span(nullptr, size_type{0}) {}

    constexpr span(pointer first, size_type count) : ptr_(first), size_(count) {}

    constexpr span(pointer first, pointer last) : span(first, last - first) {}

    template <std::size_t N>
    constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
    {
    }

    template <std::size_t N>
    constexpr span(std::array<value_type, N>& arr) noexcept : span(arr.data(), N)
    {
    }

    template <typename Container>
    constexpr span(const Container& container) : span(container.data(), container.size())
    {
    }

    constexpr iterator begin() const noexcept { return ptr_; }
    constexpr const_iterator cbegin() const noexcept { return begin(); }

    constexpr iterator end() const noexcept { return begin() + size(); }
    constexpr const_iterator cend() const noexcept { return end(); }

    constexpr reference front() const { return *begin(); }
    constexpr reference back() const { return *(--end()); }

    constexpr reference operator[](size_type idx) const { return *(begin() + idx); }
    constexpr pointer data() const noexcept { return ptr_; }

    constexpr size_type size() const noexcept { return size_; }

    private:
    pointer ptr_;
    size_type size_;
};

} // namespace ck
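This ck::span is a minimal, pre-C++20 stand-in for std::span: a non-owning pointer-plus-size view. A small host-side usage sketch, assuming only the header above (the function and buffer names here are illustrative, not from the commit):

// Usage sketch for ck::span; sum_elements/buf are made up for illustration.
#include <array>
#include <numeric>
#include "ck/utility/span.hpp"

float sum_elements(ck::span<const float> s)
{
    // begin()/end() are raw pointers, so the span works with <numeric>/<algorithm>.
    return std::accumulate(s.begin(), s.end(), 0.0f);
}

int main()
{
    std::array<float, 4> buf{1.0f, 2.0f, 3.0f, 4.0f};
    ck::span<const float> view{buf}; // container constructor: data() + size()
    return sum_elements(view) == 10.0f ? 0 : 1;
}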
include/ck/utility/transpose_vectors.hpp

@@ -34,17 +34,15 @@ __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t
     y0 = vy0.template AsType<half2_t>()[I0];
     y1 = vy1.template AsType<half2_t>()[I0];
 #else
-    asm volatile("\n \
-            v_pack_b32_f16 %0, %1, %2 \n \
-            "
-                 : "=v"(y0)
-                 : "v"(x0), "v"(x1));
-    asm volatile("\n \
-            v_pack_b32_f16 %0, %1, %2, op_sel:[1, 1] \n \
-            "
-                 : "=v"(y1)
-                 : "v"(x0), "v"(x1));
+    constexpr int32_t m0 = 0x05040100;
+    constexpr int32_t m1 = 0x07060302;
+
+    // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
+    //                   -- -- -- --      -- -- -- --      -  -  -  -
+    //      index         7  6  5  4       3  2  1  0     33 77 44 88
+    // index is reversed because of little endianness (least significant bits first)
+    y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
+    y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
 #endif
 }

@@ -106,16 +104,14 @@ __device__ void transpose_int8_4x4(const int8x4_t& x0,
     //                   -- -- -- --      -- -- -- --      -  -  -  -
     //      index         7  6  5  4       3  2  1  0     33 77 44 88
     // index is reversed because of little endianness (least significant bits first)
-    // clang-format off
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast<int32_t>(x1)), "v"(bit_cast<int32_t>(x0)), "s"(m0));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast<int32_t>(x3)), "v"(bit_cast<int32_t>(x2)), "s"(m0));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z0) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m1));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z1) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m2));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast<int32_t>(x1)), "v"(bit_cast<int32_t>(x0)), "s"(m3));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast<int32_t>(x3)), "v"(bit_cast<int32_t>(x2)), "s"(m3));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z2) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m1));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z3) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m2));
-    // clang-format on
+    t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
+    t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
+    z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
+    z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
+    t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
+    t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
+    z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
+    z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);

     y0 = bit_cast<int8x4_t>(z0);
     y1 = bit_cast<int8x4_t>(z1);
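The selector constants (e.g. m0 = 0x05040100) name, byte by byte from the least significant end, which byte of the 8-byte concatenation of the two sources lands in each output byte, exactly as the little-endian comment above describes. A host-side C++ sketch of that byte selection, written by me to illustrate the masks (not CK code, and only a model of what the hardware instruction does):

#include <cstdint>

// Emulation sketch of the byte permute used above: selector byte i (from the
// least significant end of 'sel') picks byte number sel_i out of the 8-byte
// value {src0:src1}, where src1 supplies bytes 0-3 and src0 supplies bytes 4-7.
uint32_t perm_b32(uint32_t src0, uint32_t src1, uint32_t sel)
{
    uint64_t bytes  = (static_cast<uint64_t>(src0) << 32) | src1;
    uint32_t result = 0;
    for(int i = 0; i < 4; ++i)
    {
        uint32_t s = (sel >> (8 * i)) & 0xff;   // selector for output byte i
        uint32_t b = (bytes >> (8 * s)) & 0xff; // byte s of {src0, src1}
        result |= b << (8 * i);
    }
    return result;
}

// Matches the worked example in the comment:
// perm_b32(0x11223344, 0x55667788, 0x05010400) == 0x33774488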
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>

#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename AccDataType,
          typename AccElementwiseOperation>
struct ReferenceGroupnorm : public device::BaseOperator
{
    // x = [N, H, W, G, C]
    // y = [N, H, W, G, C]
    // reduce dim [H, W, C], mean, var = [N, G]
    // gamma, beta = [G, C]

    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<XDataType>& x,
                 const Tensor<GammaDataType>& gamma,
                 const Tensor<BetaDataType>& beta,
                 Tensor<YDataType>& y,
                 AccElementwiseOperation acc_elementwise_op,
                 const std::vector<index_t> lengths,
                 AccDataType epsilon)
            : x_(x),
              gamma_(gamma),
              beta_(beta),
              y_(y),
              acc_elementwise_op_(acc_elementwise_op),
              lengths_(lengths),
              epsilon_(epsilon)
        {
        }

        const Tensor<XDataType> x_;
        const Tensor<XDataType> gamma_;
        const Tensor<XDataType> beta_;
        Tensor<YDataType>& y_;
        AccElementwiseOperation acc_elementwise_op_;
        std::vector<index_t> lengths_;
        AccDataType epsilon_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        float Run(const Argument& arg)
        {
            int N = arg.lengths_[0];
            int H = arg.lengths_[1];
            int W = arg.lengths_[2];
            int G = arg.lengths_[3];
            int C = arg.lengths_[4];

            Tensor<AccDataType> mean({N, G});
            Tensor<AccDataType> var({N, G});

            // Compute mean & var in [H, W, C] by Welford Algorithm
            // TODO - parallel for each HWC
            // TODO - address calculation
            for(int n = 0; n < N; ++n)
            {
                for(int g = 0; g < G; ++g)
                {
                    AccDataType mean_val = type_convert<AccDataType>(0.0f);
                    AccDataType var_val  = type_convert<AccDataType>(0.0f);
                    int32_t curr_count   = 0;

                    for(int h = 0; h < H; ++h)
                    {
                        for(int w = 0; w < W; ++w)
                        {
                            for(int c = 0; c < C; ++c)
                            {
                                curr_count++;
                                AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c));
                                AccDataType delta = x - mean_val;
                                mean_val += delta / curr_count;
                                AccDataType delta2 = x - mean_val;
                                var_val += delta * delta2;
                            }
                        }
                    }
                    mean(n, g) = mean_val;
                    var(n, g)  = var_val / curr_count;
                }
            }

            // Normalization
            for(int n = 0; n < N; ++n)
            {
                for(int h = 0; h < H; ++h)
                {
                    for(int w = 0; w < W; ++w)
                    {
                        for(int g = 0; g < G; ++g)
                        {
                            for(int c = 0; c < C; ++c)
                            {
                                AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c));
                                AccDataType gamma    = type_convert<AccDataType>(arg.gamma_(g, c));
                                AccDataType beta     = type_convert<AccDataType>(arg.beta_(g, c));
                                AccDataType mean_val = type_convert<AccDataType>(mean(n, g));
                                AccDataType var_val  = type_convert<AccDataType>(var(n, g));
                                AccDataType y =
                                    gamma * (x - mean_val) / ck::math::sqrt(arg.epsilon_ + var_val) +
                                    beta;
                                arg.acc_elementwise_op_(y, y);
                                arg.y_(n, h, w, g, c) = type_convert<YDataType>(y);
                            }
                        }
                    }
                }
            }

            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument* p_arg) override
    {
        const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
        if(p_arg_->lengths_.size() != 5)
            return false;

        return true;
    }

    static auto MakeArgument(const Tensor<XDataType>& x,
                             const Tensor<GammaDataType>& gamma,
                             const Tensor<BetaDataType>& beta,
                             Tensor<YDataType>& y,
                             AccElementwiseOperation acc_elementwise_op,
                             const std::vector<index_t> lengths,
                             AccDataType epsilon)
    {
        return Argument{x, gamma, beta, y, acc_elementwise_op, lengths, epsilon};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
        // clang-format off
        str << "ReferenceLayernorm" << std::endl;
        // clang-format on
        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
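The mean/variance loop above is the single-pass Welford update that gives the branch its name: each sample moves the running mean by delta/count and accumulates delta * delta2 into a sum of squared deviations, which divided by the count at the end yields the population variance. A standalone sketch of just that recurrence, extracted by me for illustration (not part of the diff):

#include <cstddef>

// Welford's online mean/variance, mirroring the delta/delta2 updates above.
struct Welford
{
    double mean       = 0.0;
    double m2         = 0.0; // running sum of squared deviations
    std::size_t count = 0;

    void update(double x)
    {
        ++count;
        const double delta = x - mean;  // deviation from the old mean
        mean += delta / count;
        const double delta2 = x - mean; // deviation from the new mean
        m2 += delta * delta2;
    }

    // Population variance, matching the reference's var_val / curr_count.
    double variance() const { return count ? m2 / count : 0.0; }
};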
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<
        Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row,
        F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16,
        PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp>>>& instances);

void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<
        Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row,
        F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16,
        PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp>>>& instances);

template <typename A0Layout,
          typename B0Layout,
          typename D0sLayout,
          typename B1Layout,
          typename D1sLayout,
          typename E1Layout,
          typename A0DataType,
          typename B0DataType,
          typename D0sDataType,
          typename B1DataType,
          typename D1sDataType,
          typename E1DataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD<
        A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout,
        A0DataType, B0DataType, D0sDataType, B1DataType, D1sDataType, E1DataType,
        PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp>>
{
    using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD<
        A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout,
        A0DataType, B0DataType, D0sDataType, B1DataType, D1sDataType, E1DataType,
        PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<A0DataType, half_t> && is_same_v<B0DataType, half_t> &&
                     is_same_v<B1DataType, half_t> && is_same_v<E1DataType, half_t>)
        {
            if constexpr(is_same_v<A0Layout, Row> && is_same_v<B0Layout, Col> &&
                         is_same_v<B1Layout, Row> && is_same_v<E1Layout, Row>)
            {
                add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                    op_ptrs);
            }
            else if constexpr(is_same_v<A0Layout, Row> && is_same_v<B0Layout, Col> &&
                              is_same_v<B1Layout, Col> && is_same_v<E1Layout, Row>)
            {
                add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
                    op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
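The factory specialization above is how client code enumerates the prebuilt kernels for a given layout/type combination. A sketch of the intended call pattern, assuming this header plus the usual CK layout/element-op aliases (the function name enumerate_instances and the local aliases are mine, for illustration):

// Sketch: enumerating registered f16 instances through the factory.
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp"

using Row         = ck::tensor_layout::gemm::RowMajor;
using Col         = ck::tensor_layout::gemm::ColumnMajor;
using F16         = ck::half_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD<
    Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row,
    F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16,
    PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp>;

void enumerate_instances()
{
    // One unique_ptr per tuning configuration registered for this signature;
    // a profiler would time each and keep the fastest supported instance.
    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    (void)op_ptrs;
}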
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
using CPermuteNumDims_G_M_O = S<2, 1, 1>;

void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        Row, Col, Row, CPermuteNumDims_G_M_O,
        F16, F16, F16, F16,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough>>>& instances);

template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CPermuteNumDims_G_M_Gemm1N,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<
        ALayout, B0Layout, B1Layout, CPermuteNumDims_G_M_Gemm1N,
        ADataType, B0DataType, B1DataType, CDataType,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough>>
{
    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<
        ALayout, B0Layout, B1Layout, CPermuteNumDims_G_M_Gemm1N,
        ADataType, B0DataType, B1DataType, CDataType,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
                     is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
                         is_same_v<B1Layout, Row> &&
                         is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>)
            {
                add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                    op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/layernorm.hpp

@@ -17,17 +17,25 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-void add_device_layernorm_f16_rank2_instances(
-    std::vector<DeviceLayernormPtr<F16, F16, F16, F32, F16, PassThrough, 2, 1>>&);
+// FP16
+void add_device_layernorm_rank_2_1_f16_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F16, F16, F16, F32, F16, PassThrough, 2, 1>>>&);

-void add_device_layernorm_f16_rank4_instances(
-    std::vector<DeviceLayernormPtr<F16, F16, F16, F32, F16, PassThrough, 4, 3>>&);
+void add_device_layernorm_rank_4_3_f16_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F16, F16, F16, F32, F16, PassThrough, 4, 3>>>&);

-void add_device_layernorm_f32_rank2_instances(
-    std::vector<DeviceLayernormPtr<F32, F32, F32, F32, F32, PassThrough, 2, 1>>&);
+void add_device_layernorm_rank_5_3_f16_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F16, F16, F16, F32, F16, PassThrough, 5, 3>>>&);

-void add_device_layernorm_f32_rank4_instances(
-    std::vector<DeviceLayernormPtr<F32, F32, F32, F32, F32, PassThrough, 4, 3>>&);
+// FP32
+void add_device_layernorm_rank_2_1_f32_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F32, F32, F32, F32, F32, PassThrough, 2, 1>>>&);
+
+void add_device_layernorm_rank_4_3_f32_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
+
+void add_device_layernorm_rank_5_3_f32_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F32, F32, F32, F32, F32, PassThrough, 5, 3>>>&);

 template <typename XDataType,
           typename GammaDataType,

@@ -62,17 +70,33 @@ struct DeviceOperationInstanceFactory<
                      is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
         {
             if constexpr(Rank == 2 && NumReduceDim == 1)
-                add_device_layernorm_f16_rank2_instances(op_ptrs);
+            {
+                add_device_layernorm_rank_2_1_f16_instances(op_ptrs);
+            }
             else if constexpr(Rank == 4 && NumReduceDim == 3)
-                add_device_layernorm_f16_rank4_instances(op_ptrs);
+            {
+                add_device_layernorm_rank_4_3_f16_instances(op_ptrs);
+            }
+            else if constexpr(Rank == 5 && NumReduceDim == 3)
+            {
+                add_device_layernorm_rank_5_3_f16_instances(op_ptrs);
+            }
         }
         else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
                           is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
         {
             if constexpr(Rank == 2 && NumReduceDim == 1)
-                add_device_layernorm_f32_rank2_instances(op_ptrs);
+            {
+                add_device_layernorm_rank_2_1_f32_instances(op_ptrs);
+            }
             else if constexpr(Rank == 4 && NumReduceDim == 3)
-                add_device_layernorm_f32_rank4_instances(op_ptrs);
+            {
+                add_device_layernorm_rank_4_3_f32_instances(op_ptrs);
+            }
+            else if constexpr(Rank == 5 && NumReduceDim == 3)
+            {
+                add_device_layernorm_rank_5_3_f32_instances(op_ptrs);
+            }
         }

         return op_ptrs;
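The renaming folds the rank and reduce-dim counts into a single rank_R_N infix (rank_4_3 means Rank = 4, NumReduceDim = 3) and switches the signatures from the DeviceLayernormPtr alias to explicit std::vector<std::unique_ptr<DeviceLayernorm<...>>>. The new rank_5_3 instances presumably back the 5-D [N, H, W, G, C] groupnorm path added on this branch, which reduces over three dimensions ([H, W, C]).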
library/include/ck/library/utility/check_err.hpp

@@ -15,6 +15,7 @@
 #include "ck/ck.hpp"
 #include "ck/utility/data_type.hpp"
+#include "ck/utility/span.hpp"
 #include "ck/utility/type.hpp"

 #include "ck/host_utility/io.hpp"

@@ -32,7 +33,7 @@ check_err(const std::vector<T>& out,
 {
     if(out.size() != ref.size())
     {
-        std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                   << std::endl;
         return false;
     }

@@ -50,7 +51,7 @@ check_err(const std::vector<T>& out,
             err_count++;
             if(err_count < 5)
             {
-                std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                           << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl;
             }
             res = false;

@@ -58,7 +59,7 @@ check_err(const std::vector<T>& out,
     }
     if(!res)
     {
-        std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
     }
     return res;
 }

@@ -73,7 +74,7 @@ check_err(const std::vector<T>& out,
 {
     if(out.size() != ref.size())
     {
-        std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                   << std::endl;
         return false;
     }

@@ -94,7 +95,7 @@ check_err(const std::vector<T>& out,
             err_count++;
             if(err_count < 5)
             {
-                std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                           << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
             }
             res = false;

@@ -102,22 +103,22 @@ check_err(const std::vector<T>& out,
     }
     if(!res)
     {
-        std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
     }
     return res;
 }

 template <typename T>
-typename std::enable_if<std::is_same<T, half_t>::value, bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+typename std::enable_if<std::is_same_v<T, half_t>, bool>::type
+check_err(span<const T> out,
+          span<const T> ref,
           const std::string& msg = "Error: Incorrect results!",
           double rtol            = 1e-3,
          double atol            = 1e-3)
 {
     if(out.size() != ref.size())
     {
-        std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                   << std::endl;
         return false;
     }

@@ -137,7 +138,7 @@ check_err(const std::vector<T>& out,
             err_count++;
             if(err_count < 5)
             {
-                std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                           << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
             }
             res = false;

@@ -145,11 +146,22 @@ check_err(const std::vector<T>& out,
     }
     if(!res)
     {
-        std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
     }
     return res;
 }

+template <typename T>
+typename std::enable_if<std::is_same<T, half_t>::value, bool>::type
+check_err(const std::vector<T>& out,
+          const std::vector<T>& ref,
+          const std::string& msg = "Error: Incorrect results!",
+          double rtol            = 1e-3,
+          double atol            = 1e-3)
+{
+    return check_err(span<const T>{out}, span<const T>{ref}, msg, rtol, atol);
+}
+
 template <typename T>
 std::enable_if_t<(std::is_integral_v<T> && !std::is_same_v<T, bhalf_t>)
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4

@@ -194,7 +206,7 @@ check_err(const std::vector<T>& out,
     }
     if(!res)
     {
-        std::cout << "max err: " << max_err << std::endl;
+        std::cerr << "max err: " << max_err << std::endl;
     }
     return res;
 }
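With this change the half_t comparison is implemented once over spans, the old vector signature becomes a forwarding shim, and all diagnostics move from stdout to stderr. A usage sketch, assuming check_err's usual ck::utils namespace (the function compare_outputs and the data values are illustrative):

#include <vector>
#include "ck/library/utility/check_err.hpp"

bool compare_outputs()
{
    std::vector<ck::half_t> out(16, ck::type_convert<ck::half_t>(1.0f));
    std::vector<ck::half_t> ref(16, ck::type_convert<ck::half_t>(1.0f));

    // The vector overload wraps both sides in ck::span<const half_t> and
    // forwards to the span-based implementation; mismatches print to stderr.
    return ck::utils::check_err(out, ref, "gemm+layernorm verification");
}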
library/include/ck/library/utility/fill.hpp

@@ -5,7 +5,10 @@

 #include <algorithm>
 #include <cmath>
+#include <iterator>
 #include <random>
+#include <type_traits>
+#include <utility>

 #include "ck/utility/data_type.hpp"

@@ -25,6 +28,15 @@ struct FillUniformDistribution
         std::uniform_real_distribution<float> dis(a_, b_);
         std::generate(first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
     }
+
+    template <typename ForwardRange>
+    auto operator()(ForwardRange&& range) -> std::void_t<
+        decltype(std::declval<FillUniformDistribution>()(std::begin(std::forward<ForwardRange>(range)),
+                                                         std::end(std::forward<ForwardRange>(range))))>
+    {
+        (*this)(std::begin(std::forward<ForwardRange>(range)),
+                std::end(std::forward<ForwardRange>(range)));
+    }
 };

 // Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below.
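The new range overload lets callers pass a whole container instead of an iterator pair; the std::void_t return type SFINAEs it away for types without std::begin/std::end. A usage sketch, assuming the usual ck::utils namespace and the aggregate shape of FillUniformDistribution (the function fill_example is mine):

#include <vector>
#include "ck/library/utility/fill.hpp"

void fill_example()
{
    std::vector<float> data(1024);

    ck::utils::FillUniformDistribution<float> fill{-1.0f, 1.0f};

    fill(data.begin(), data.end()); // iterator-pair form (pre-existing)
    fill(data);                     // range form added by this diff
}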
library/include/ck/library/utility/host_tensor.hpp

@@ -3,15 +3,16 @@
 #pragma once

-#include <thread>
-#include <vector>
-#include <numeric>
 #include <algorithm>
-#include <utility>
+#include <cassert>
+#include <iostream>
+#include <numeric>
+#include <thread>
+#include <utility>
+#include <vector>

 #include "ck/utility/data_type.hpp"
+#include "ck/utility/span.hpp"

 template <typename Range>
 std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)

@@ -235,6 +236,9 @@ auto make_ParallelTensorFunctor(F f, Xs... xs)
 template <typename T>
 struct Tensor
 {
+    using Descriptor = HostTensorDescriptor;
+    using Data       = std::vector<T>;
+
     template <typename X>
     Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
     {

@@ -251,7 +255,7 @@ struct Tensor
     {
     }

-    Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
+    Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}

     template <typename OutT>
     Tensor<OutT> CopyAsType() const

@@ -278,9 +282,9 @@ struct Tensor
     {
     }

-    const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }
+    decltype(auto) GetLengths() const { return mDesc.GetLengths(); }

-    const std::vector<std::size_t>& GetStrides() const { return mDesc.GetStrides(); }
+    decltype(auto) GetStrides() const { return mDesc.GetStrides(); }

     std::size_t GetNumOfDimension() const { return mDesc.GetNumOfDimension(); }

@@ -288,6 +292,8 @@ struct Tensor
     std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }

+    std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
+
     void SetZero()
     {
         for(auto& v : mData)

@@ -425,14 +431,40 @@ struct Tensor
         return mData[mDesc.GetOffsetFromMultiIndex(idx)];
     }

-    typename std::vector<T>::iterator begin() { return mData.begin(); }
+    typename Data::iterator begin() { return mData.begin(); }

-    typename std::vector<T>::iterator end() { return mData.end(); }
+    typename Data::iterator end() { return mData.end(); }

+    typename Data::pointer data() { return mData.data(); }
+
-    typename std::vector<T>::const_iterator begin() const { return mData.begin(); }
+    typename Data::const_iterator begin() const { return mData.begin(); }

-    typename std::vector<T>::const_iterator end() const { return mData.end(); }
+    typename Data::const_iterator end() const { return mData.end(); }

+    typename Data::const_pointer data() const { return mData.data(); }
+
+    typename Data::size_type size() const { return mData.size(); }
+
+    template <typename U = T>
+    auto AsSpan() const
+    {
+        constexpr std::size_t FromSize = sizeof(T);
+        constexpr std::size_t ToSize   = sizeof(U);
+
+        using Element = std::add_const_t<std::remove_reference_t<U>>;
+        return ck::span<Element>{reinterpret_cast<Element*>(data()), size() * FromSize / ToSize};
+    }
+
+    template <typename U = T>
+    auto AsSpan()
+    {
+        constexpr std::size_t FromSize = sizeof(T);
+        constexpr std::size_t ToSize   = sizeof(U);
+
+        using Element = std::remove_reference_t<U>;
+        return ck::span<Element>{reinterpret_cast<Element*>(data()), size() * FromSize / ToSize};
+    }
+
-    HostTensorDescriptor mDesc;
-    std::vector<T> mData;
+    Descriptor mDesc;
+    Data mData;
 };
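AsSpan() reinterprets the tensor's flat storage as a span of some element type U, rescaling the element count by sizeof(T)/sizeof(U); with the default U = T it is simply a view over mData. A sketch (the function and values are illustrative):

#include <cstdint>
#include "ck/library/utility/host_tensor.hpp"

void as_span_example()
{
    Tensor<float> t({4, 4});                // 16 floats of storage

    auto elems = t.AsSpan();                // ck::span<float>, size() == 16

    // Reinterpret the same bytes as 32-bit words: the element count is
    // rescaled by sizeof(T) / sizeof(U), i.e. 16 * 4 / 4 == 16 here.
    auto words = t.AsSpan<std::uint32_t>(); // ck::span<std::uint32_t>

    (void)elems;
    (void)words;
}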
library/src/tensor_operation_instance/gpu/CMakeLists.txt

@@ -6,61 +6,19 @@ function(add_instance_library INSTANCE_NAME)
     clang_tidy_check(${INSTANCE_NAME})
 endfunction(add_instance_library INSTANCE_NAME)

-add_subdirectory(gemm)
-add_subdirectory(gemm_splitk)
-add_subdirectory(gemm_bilinear)
-add_subdirectory(gemm_add_add_fastgelu)
-add_subdirectory(gemm_reduce)
-add_subdirectory(gemm_bias_add_reduce)
-add_subdirectory(batched_gemm)
-add_subdirectory(batched_gemm_reduce)
-add_subdirectory(batched_gemm_gemm)
-add_subdirectory(batched_gemm_softmax_gemm)
-add_subdirectory(grouped_gemm)
-add_subdirectory(contraction_scale)
-add_subdirectory(contraction_bilinear)
-add_subdirectory(grouped_conv1d_fwd)
-add_subdirectory(grouped_conv2d_fwd)
-add_subdirectory(grouped_conv3d_fwd)
-add_subdirectory(conv2d_fwd)
-add_subdirectory(conv1d_bwd_data)
-add_subdirectory(conv2d_bwd_data)
-add_subdirectory(conv3d_bwd_data)
-add_subdirectory(conv1d_bwd_weight)
-add_subdirectory(conv2d_bwd_weight)
-add_subdirectory(conv3d_bwd_weight)
-add_subdirectory(conv2d_fwd_bias_relu)
-add_subdirectory(conv2d_fwd_bias_relu_add)
-add_subdirectory(reduce)
-add_subdirectory(normalization)
-add_subdirectory(elementwise)
+file(GLOB dir_list LIST_DIRECTORIES true *)
+set(CK_DEVICE_INSTANCES)
+FOREACH(subdir_path ${dir_list})
+    set(target_dir)
+    IF(IS_DIRECTORY "${subdir_path}")
+        get_filename_component(target_dir ${subdir_path} NAME)
+        add_subdirectory(${target_dir})
+        list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
+    ENDIF()
+ENDFOREACH()

-add_library(device_operations STATIC
-    $<TARGET_OBJECTS:device_gemm_instance>
-    $<TARGET_OBJECTS:device_gemm_splitk_instance>
-    $<TARGET_OBJECTS:device_gemm_bilinear_instance>
-    $<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
-    $<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
-    $<TARGET_OBJECTS:device_batched_gemm_instance>
-    $<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
-    $<TARGET_OBJECTS:device_grouped_gemm_instance>
-    $<TARGET_OBJECTS:device_contraction_scale_instance>
-    $<TARGET_OBJECTS:device_contraction_bilinear_instance>
-    $<TARGET_OBJECTS:device_grouped_conv1d_fwd_instance>
-    $<TARGET_OBJECTS:device_grouped_conv2d_fwd_instance>
-    $<TARGET_OBJECTS:device_grouped_conv3d_fwd_instance>
-    $<TARGET_OBJECTS:device_conv1d_bwd_data_instance>
-    $<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
-    $<TARGET_OBJECTS:device_conv3d_bwd_data_instance>
-    $<TARGET_OBJECTS:device_conv1d_bwd_weight_instance>
-    $<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
-    $<TARGET_OBJECTS:device_conv3d_bwd_weight_instance>
-    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
-    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
-    $<TARGET_OBJECTS:device_reduce_instance>
-    $<TARGET_OBJECTS:device_normalization_instance>
-    $<TARGET_OBJECTS:device_elementwise_instance>
-)
+add_library(device_operations STATIC ${CK_DEVICE_INSTANCES})
 add_library(composablekernels::device_operations ALIAS device_operations)
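The rewritten CMakeLists replaces the hand-maintained add_subdirectory and $<TARGET_OBJECTS:...> lists with a glob over the instance directories: every subdirectory is added automatically and is expected to define a device_<dir>_instance object target, so new operation folders (such as the two added in this commit) no longer require edits to this file.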
library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt  (new file, mode 100644)

add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
    device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
    device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough   = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;

// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
    // clang-format off
    //##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K|NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
    //##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
    //##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
    //##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
    // no padding
    DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
    DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
    DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
    DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
    DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1, 16>, 8>,
    DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
    DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1, 16>, 8>,
    DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
    // Padded fallback kernel
    DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
    DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<
        Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row,
        F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16,
        PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
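Each tuple entry above is one fully specialized kernel: the leading false/true flags select per-dimension padding and the trailing numeric arguments are the tile, wave, and block-transfer tuning parameters spelled out in the comment table, with add_device_operation_instances materializing every entry into the registry. On my reading of the "no padding" / "Padded fallback kernel" comments, the unpadded configurations require problem sizes that tile exactly, while the padded entries at the end accept arbitrary sizes as a slower fallback.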
library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
0 → 100644
View file @
5aa3c344
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CDE0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
;
using
CDE1ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Add
;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances
=
std
::
tuple
<
// clang-format off
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K| NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1| A0K1| B0K1|B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// no padding
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
256
,
128
,
32
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
2
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
64
,
256
,
32
,
128
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
64
,
256
,
32
,
64
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
64
,
256
,
64
,
128
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
false
,
false
,
false
,
false
,
false
,
1
,
256
,
64
,
256
,
64
,
64
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
// Padded fallback kernel
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
true
,
true
,
true
,
true
,
true
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<
Row
,
Col
,
ck
::
Tuple
<
Row
>
,
Col
,
ck
::
Tuple
<
Row
>
,
Row
,
F16
,
F16
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
F32
,
F32
,
ck
::
Tuple
<
F16
>
,
F16
,
PassThrough
,
PassThrough
,
CDE0ElementOp
,
PassThrough
,
CDE1ElementOp
,
true
,
true
,
true
,
true
,
true
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
    std::vector<std::unique_ptr<
        DeviceBatchedGemmMultipleDGemmMultipleD<Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row,
                                                F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16,
                                                PassThrough, PassThrough, CDE0ElementOp,
                                                PassThrough, CDE1ElementOp>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
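For orientation, the `add_device_*_instance` functions above follow the library's instance-factory pattern: the caller owns a vector of type-erased operation pointers, and the factory appends one concrete tuned kernel per tuple entry. The following consumer-side sketch is illustrative only and not part of the commit; it assumes the aliases above are in scope and that, as in the library's common base interface, each instance exposes GetTypeString(). Argument construction is operation-specific and elided.

// Illustrative sketch: enumerate the kernels registered for this layout combination.
#include <iostream>
#include <memory>
#include <vector>

void list_registered_instances()
{
    std::vector<std::unique_ptr<
        DeviceBatchedGemmMultipleDGemmMultipleD<Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row,
                                                F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16,
                                                PassThrough, PassThrough, CDE0ElementOp,
                                                PassThrough, CDE1ElementOp>>>
        instances;

    add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
        instances);

    // A client would normally build an argument, keep only instances for which
    // IsSupportedArgument(...) holds, and then time or dispatch the survivors.
    for(const auto& op : instances)
        std::cout << op->GetTypeString() << '\n';
}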
library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt
0 → 100644
add_instance_library(device_batched_gemm_masking_scale_softmax_gemm_permute_instance
    device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp)
library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using CPermuteNumDims_G_M_O = S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale       = ck::tensor_operation::element_wise::Scale;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded  = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
// clang-format off
// 2 of them are commented out because they trigger the clang-13 issue.
//##############################################| ALayout| B0Layout| B1Layout| CPermuteNumDims_G_M_O| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskOut|
//##############################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper|
//##############################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle|
//##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
//DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
//DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1, 16>, 8, true>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, true>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1, 16>, 8, true>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, true>,
// Padded fallback kernel
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>
    // clang-format on
    >;
void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<
        DeviceBatchedGemmSoftmaxGemmPermute<Row, Col, Row, CPermuteNumDims_G_M_O,
                                            F16, F16, F16, F16, PassThrough, PassThrough, Scale,
                                            PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
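The recurring layout comment `// c[g, m, n] = a[g, m, k] * b[g, n, k]` means the B operand is indexed [g, n, k], i.e. K-contiguous per output column, which is why B0Layout is Col in these instances. The following naive host-side reference of that contraction is purely illustrative and not part of the commit:

// Naive reference for c[g, m, n] = sum_k a[g, m, k] * b[g, n, k];
// all tensors are assumed densely packed in the index order shown.
void reference_batched_gemm_gmk_gnk(const float* a, const float* b, float* c,
                                    int G, int M, int N, int K)
{
    for(int g = 0; g < G; ++g)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += a[(g * M + m) * K + k] * b[(g * N + n) * K + k];
                c[(g * M + m) * N + n] = acc;
            }
}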
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...
...
@@ -26,32 +26,47 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNOPadding;
// Padding K is currently flawed
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
// clang-format off
//#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1, 16>, 8>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1, 16>, 8>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
        //#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskOut|
        //#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper|
        //#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle|
        //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1, 16>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1, 16>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, false>,
// Padded fallback kernel
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>
    // clang-format on
    >;
using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances = std::tuple<
// clang-format off
//#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskOut|
//#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>
    // clang-format on
    >;
...
...
@@ -73,6 +88,9 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g
    add_device_operation_instances(
        instances, device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
    add_device_operation_instances(
        instances,
        device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances{});
}

} // namespace instance
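A note on the "_irregular_k" variants registered above (my reading of the visible parameters, not a statement from the commit): they appear to cover K extents where the usual K1 = 8 vector width does not divide the K tile, so they use a K-per-block of 40 with AK1 = BK1 = 4 and the padded GEMM specialization. A trivial compile-time illustration of that divisibility constraint:

// Illustrative only: the irregular-K tile of 40 rules out K1 = 8 but admits K1 = 4.
static_assert(40 % 8 != 0 && 40 % 4 == 0, "irregular K tile needs the narrower K1");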
...
...
library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
...
...
@@ -37,6 +37,7 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple<
// clang-format off
// no padding
// N % 8 == 0 && K % 8 == 0
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
...
...
@@ -55,7 +56,8 @@ using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances =
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// M/N/N padding
// M/N/K padding
// N % 8 == 0 && K % 8 == 0
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
...
...
@@ -72,7 +74,48 @@ using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances =
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// M/N/K padding
// N % 4 == 0 && K % 4 == 0
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
// M/N/K padding
// N % 8 == 0 && K % 1 == 0
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 32>, 1>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 32>, 1>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1>,
        DeviceGemmMultipleD_Xdl_CShuffle<Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1>
    // clang-format on
    >;
...
...
library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f16_instance.cpp
...
...
@@ -17,34 +17,40 @@ using F32 = float;
using Pass = ck::tensor_operation::element_wise::PassThrough;

template <index_t Rank, index_t Reduce>
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_layernorm_f16_instances = std::tuple<
    // clang-format off
    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
,
1
,
1
>
,
// fallback kernel
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
2
,
2
,
2
,
2
>
,
// fallback kernel
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
4
,
4
,
4
,
4
>
,
// fallback kernel
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
8
,
8
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
4
,
64
,
1
,
8
,
1
,
8
,
8
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
8
,
1
,
8
,
8
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
16
,
1
,
8
,
8
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
32
,
1
,
8
,
8
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
8
,
8
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
8
,
8
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
8
,
8
,
8
,
8
>
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVector
Dim, GammaSrcVectorSize, BetaSrcVectorDim
, BetaSrcVectorSize, YDstVectorSize>
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// fallback kernel
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
2
,
1
,
2
,
1
,
2
,
2
>
,
// fallback kernel
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
// fallback kernel
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
4
,
64
,
1
,
8
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
8
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
16
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
32
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
// clang-format on
>
;
-void add_device_layernorm_f16_rank2_instances(
-    std::vector<DeviceLayernormPtr<F16, F16, F16, F32, F16, Pass, 2, 1>>& instances)
+void add_device_layernorm_rank_2_1_f16_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F16, F16, F16, F32, F16, Pass, 2, 1>>>& instances)
{
-    add_device_operation_instances(instances, device_layernorm_f16_instances<2, 1>{});
+    add_device_operation_instances(instances, device_layernorm_f16_instances<Pass, 2, 1>{});
}
-void add_device_layernorm_f16_rank4_instances(
-    std::vector<DeviceLayernormPtr<F16, F16, F16, F32, F16, Pass, 4, 3>>& instances)
+void add_device_layernorm_rank_4_3_f16_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F16, F16, F16, F32, F16, Pass, 4, 3>>>& instances)
{
-    add_device_operation_instances(instances, device_layernorm_f16_instances<4, 3>{});
+    add_device_operation_instances(instances, device_layernorm_f16_instances<Pass, 4, 3>{});
}
+void add_device_layernorm_rank_5_3_f16_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F16, F16, F16, F32, F16, Pass, 5, 3>>>& instances)
+{
+    add_device_operation_instances(instances, device_layernorm_f16_instances<Pass, 5, 3>{});
+}
} // namespace instance
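The new entry points hand back owning pointers, so a caller can enumerate the registered kernels directly. A minimal caller sketch under this diff's own declarations (header paths and namespace spellings are assumed, not shown in the diff):

// Sketch: collect the rank-2 / 1-reduce-dim f16 layernorm instances registered above.
// Assumes the usual CK namespaces (ck::tensor_operation::device[::instance]) and that
// the relevant library headers are included; F16/F32/Pass as defined in this file.
std::vector<std::unique_ptr<DeviceLayernorm<F16, F16, F16, F32, F16, Pass, 2, 1>>> instances;
add_device_layernorm_rank_2_1_f16_instances(instances);
// Each returned instance can then be checked with IsSupportedArgument(...) and
// launched through its invoker, following the library's usual pattern.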
...
...
library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f32_instance.cpp
...
...
@@ -16,33 +16,39 @@ using F32 = float;
using Pass = ck::tensor_operation::element_wise::PassThrough;
-template <index_t Rank, index_t Reduce>
+template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_layernorm_f32_instances = std::tuple<
    // clang-format off
-    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1>,
-    // fallback kernel
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 2, 2, 2>,
-    // fallback kernel
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 4, 4, 4>,
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 4, 64, 1, 8, 1, 4, 4, 4, 4>,
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 8, 1, 4, 4, 4, 4>,
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 16, 1, 4, 4, 4, 4>,
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 32, 1, 4, 4, 4, 4>,
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 4, 4, 4>,
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 4, 4, 4>,
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 4, 4, 4>
+    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1, 1, 1>,
+    // fallback kernel
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 1, 2, 1, 2, 2>,
+    // fallback kernel
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 4, 64, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>
    // clang-format on
    >;
-void add_device_layernorm_f32_rank2_instances(
-    std::vector<DeviceLayernormPtr<F32, F32, F32, F32, F32, Pass, 2, 1>>& instances)
+void add_device_layernorm_rank_2_1_f32_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F32, F32, F32, F32, F32, Pass, 2, 1>>>& instances)
{
-    add_device_operation_instances(instances, device_layernorm_f32_instances<2, 1>{});
+    add_device_operation_instances(instances, device_layernorm_f32_instances<Pass, 2, 1>{});
}
-void add_device_layernorm_f32_rank4_instances(
-    std::vector<DeviceLayernormPtr<F32, F32, F32, F32, F32, Pass, 4, 3>>& instances)
+void add_device_layernorm_rank_4_3_f32_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F32, F32, F32, F32, F32, Pass, 4, 3>>>& instances)
{
-    add_device_operation_instances(instances, device_layernorm_f32_instances<4, 3>{});
+    add_device_operation_instances(instances, device_layernorm_f32_instances<Pass, 4, 3>{});
}
+void add_device_layernorm_rank_5_3_f32_instances(
+    std::vector<std::unique_ptr<DeviceLayernorm<F32, F32, F32, F32, F32, Pass, 5, 3>>>& instances)
+{
+    add_device_operation_instances(instances, device_layernorm_f32_instances<Pass, 5, 3>{});
+}
} // namespace instance
...
...
profiler/CMakeLists.txt
...
...
@@ -12,6 +12,8 @@ set(PROFILER_SOURCE
    src/profile_gemm_add_add_fastgelu.cpp
    src/profile_gemm_reduce.cpp
    src/profile_batched_gemm.cpp
+   src/profile_batched_gemm_gemm.cpp
+   src/profile_batched_gemm_add_relu_gemm_add.cpp
    src/profile_batched_gemm_reduce.cpp
    src/profile_grouped_gemm.cpp
    src/profile_conv_fwd.cpp
...
...
@@ -21,6 +23,7 @@ set(PROFILER_SOURCE
    src/profile_conv_bwd_weight.cpp
    src/profile_grouped_conv_fwd.cpp
    src/profile_reduce.cpp
+   src/profile_groupnorm.cpp
    src/profile_layernorm.cpp
    src/profile_normalization.cpp
)
...
...
@@ -35,6 +38,8 @@ target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
+target_link_libraries(ckProfiler PRIVATE device_batched_gemm_gemm_instance)
+target_link_libraries(ckProfiler PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
...
...