Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a72a5762
Commit
a72a5762
authored
Feb 24, 2023
by
Chao Liu
Browse files
Merge remote-tracking branch 'origin/develop' into tile
parents
b00ae5df
209baee2
Changes
124
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4863 additions
and
31 deletions
+4863
-31
example/29_batched_gemm_bias_e_permute/CMakeLists.txt
example/29_batched_gemm_bias_e_permute/CMakeLists.txt
+4
-0
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
..._bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
+431
-0
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
+3
-0
example/30_grouped_conv_fwd_multiple_d/common.hpp
example/30_grouped_conv_fwd_multiple_d/common.hpp
+1
-1
example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp
example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp
+355
-0
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp
...d_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp
+26
-0
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
...ple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
+286
-0
example/42_groupnorm/groupnorm_sigmoid_fp16.cpp
example/42_groupnorm/groupnorm_sigmoid_fp16.cpp
+7
-7
example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc
...le/46_gemm_add_multiply/run_gemm_add_multiply_example.inc
+2
-3
include/ck/ck.hpp
include/ck/ck.hpp
+0
-7
include/ck/host_utility/kernel_launch.hpp
include/ck/host_utility/kernel_launch.hpp
+5
-4
include/ck/tensor_operation/gpu/device/device_normalization.hpp
...e/ck/tensor_operation/gpu/device/device_normalization.hpp
+7
-7
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
...l/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
+991
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+1
-0
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
...evice/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+5
-0
include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp
...gpu/device/impl/device_elementwise_normalization_impl.hpp
+5
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
...ce/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+14
-2
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
.../gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
+654
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
...impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
+1216
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+850
-0
No files found.
example/29_batched_gemm_bias_e_permute/CMakeLists.txt
View file @
a72a5762
add_example_executable
(
example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp
)
if
(
GPU_TARGETS MATCHES
"gfx1100"
)
add_example_executable
(
example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp
)
endif
()
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
0 → 100644
View file @
a72a5762
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F16
;
using
DDataType
=
F16
;
using
DsDataType
=
ck
::
Tuple
<
DDataType
>
;
using
EDataType
=
F16
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimM
=
2
;
static
constexpr
ck
::
index_t
NumDimN
=
2
;
static
constexpr
ck
::
index_t
NumDimK
=
1
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CDEElementOp
=
ck
::
tensor_operation
::
element_wise
::
Add
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
ABSpec
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Packed
;
static
constexpr
auto
DESpec
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
using
DeviceOpInstanceKKNN
=
ck
::
tensor_operation
::
device
::
DeviceBatchedContractionMultipleD_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
ABSpec
,
ABSpec
,
DESpec
,
256
,
128
,
256
,
8
,
8
,
16
,
16
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
using
DeviceOpInstance
=
DeviceOpInstanceKKNN
;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template
<
ck
::
index_t
NumDimG
,
ck
::
index_t
NumDimM
,
ck
::
index_t
NumDimN
,
ck
::
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
EDataType
,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
ck
::
enable_if_t
<
NumDimG
==
2
&&
NumDimM
==
2
&&
NumDimN
==
2
&&
NumDimK
==
1
,
bool
>
=
false
>
struct
ReferenceContraction_G2_M2_N2_K1
:
public
ck
::
tensor_operation
::
device
::
BaseOperator
{
// Argument
struct
Argument
:
public
ck
::
tensor_operation
::
device
::
BaseArgument
{
Argument
(
const
Tensor
<
ADataType
>&
a_gs_ms_ks
,
const
Tensor
<
BDataType
>&
b_gs_ns_ks
,
Tensor
<
EDataType
>&
e_gs_ms_ns
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
:
a_gs_ms_ks_
{
a_gs_ms_ks
},
b_gs_ns_ks_
{
b_gs_ns_ks
},
e_gs_ms_ns_
{
e_gs_ms_ns
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
{
}
const
Tensor
<
ADataType
>&
a_gs_ms_ks_
;
const
Tensor
<
BDataType
>&
b_gs_ns_ks_
;
Tensor
<
EDataType
>&
e_gs_ms_ns_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
};
// Invoker
struct
Invoker
:
public
ck
::
tensor_operation
::
device
::
BaseInvoker
{
using
Argument
=
ReferenceContraction_G2_M2_N2_K1
::
Argument
;
float
Run
(
const
Argument
&
arg
)
{
auto
f_ms_ns
=
[
&
](
auto
g0
,
auto
g1
,
auto
m0
,
auto
m1
,
auto
n0
,
auto
n1
)
{
const
int
K0
=
arg
.
a_gs_ms_ks_
.
mDesc
.
GetLengths
()[
4
];
AccDataType
v_acc
=
0
;
for
(
int
k0
=
0
;
k0
<
K0
;
++
k0
)
{
AccDataType
v_a
;
AccDataType
v_b
;
arg
.
a_element_op_
(
v_a
,
ck
::
type_convert
<
const
AccDataType
>
(
arg
.
a_gs_ms_ks_
(
g0
,
g1
,
m0
,
m1
,
k0
)));
arg
.
b_element_op_
(
v_b
,
ck
::
type_convert
<
const
AccDataType
>
(
arg
.
b_gs_ns_ks_
(
g0
,
g1
,
n0
,
n1
,
k0
)));
v_acc
+=
v_a
*
v_b
;
}
AccDataType
v_c
;
arg
.
cde_element_op_
(
v_c
,
v_acc
);
arg
.
e_gs_ms_ns_
(
g0
,
g1
,
m0
,
m1
,
n0
,
n1
)
=
v_c
;
};
make_ParallelTensorFunctor
(
f_ms_ns
,
arg
.
e_gs_ms_ns_
.
mDesc
.
GetLengths
()[
0
],
arg
.
e_gs_ms_ns_
.
mDesc
.
GetLengths
()[
1
],
arg
.
e_gs_ms_ns_
.
mDesc
.
GetLengths
()[
2
],
arg
.
e_gs_ms_ns_
.
mDesc
.
GetLengths
()[
3
],
arg
.
e_gs_ms_ns_
.
mDesc
.
GetLengths
()[
4
],
arg
.
e_gs_ms_ns_
.
mDesc
.
GetLengths
()[
5
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
float
Run
(
const
ck
::
tensor_operation
::
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
bool
IsSupportedArgument
(
const
ck
::
tensor_operation
::
device
::
BaseArgument
*
)
override
{
return
true
;
}
static
auto
MakeArgument
(
const
Tensor
<
ADataType
>&
a_gs_ms_ks
,
const
Tensor
<
BDataType
>&
b_gs_ns_ks
,
Tensor
<
EDataType
>&
e_gs_ms_ns
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
a_gs_ms_ks
,
b_gs_ns_ks
,
e_gs_ms_ns
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
virtual
std
::
unique_ptr
<
ck
::
tensor_operation
::
device
::
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"ReferenceContraction_G2_M2_N2_K1"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
}
};
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
true
;
ck
::
index_t
G0
=
1
;
ck
::
index_t
G1
=
2
;
ck
::
index_t
M0
=
4
;
ck
::
index_t
M1
=
128
;
ck
::
index_t
N0
=
16
;
ck
::
index_t
N1
=
256
;
ck
::
index_t
K0
=
2048
;
// A[G0, G1, M0, M1, K0]
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M0
,
M1
,
K0
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
{
G1
*
M0
*
M1
*
K0
,
M0
*
M1
*
K0
,
M1
*
K0
,
K0
,
1
};
// B[G0, G1, N0, N1, K0]
std
::
vector
<
ck
::
index_t
>
b_gs_ns_ks_lengths
{
G0
,
G1
,
N0
,
N1
,
K0
};
std
::
vector
<
ck
::
index_t
>
b_gs_ns_ks_strides
{
G1
*
N0
*
N1
*
K0
,
N0
*
N1
*
K0
,
N1
*
K0
,
K0
,
1
};
// D[G0, G1, M0, N0, M1, N1]
std
::
vector
<
ck
::
index_t
>
d_gs_ms_ns_lengths
{
G0
,
G1
,
M0
,
M1
,
N0
,
N1
};
std
::
vector
<
ck
::
index_t
>
d_gs_ms_ns_strides
{
G1
*
N0
*
N1
,
N0
*
N1
,
0
,
0
,
N1
,
1
};
// E[G0, G1, M0, N0, M1, N1]
std
::
vector
<
ck
::
index_t
>
e_gs_ms_ns_lengths
{
G0
,
G1
,
M0
,
M1
,
N0
,
N1
};
std
::
vector
<
ck
::
index_t
>
e_gs_ms_ns_strides
{
G1
*
M0
*
N0
*
M1
*
N1
,
M0
*
N0
*
M1
*
N1
,
N0
*
M1
*
N1
,
N1
,
M1
*
N1
,
1
};
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
exit
(
0
);
}
Tensor
<
ADataType
>
a_gs_ms_ks
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
Tensor
<
BDataType
>
b_gs_ns_ks
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
);
Tensor
<
DDataType
>
d_gs_ms_ns
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
);
Tensor
<
EDataType
>
e_gs_ms_ns_host_result
(
e_gs_ms_ns_lengths
,
e_gs_ms_ns_strides
);
Tensor
<
EDataType
>
e_gs_ms_ns_device_result
(
e_gs_ms_ns_lengths
,
e_gs_ms_ns_strides
);
std
::
cout
<<
"a_gs_ms_ks: "
<<
a_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_gs_ns_ks: "
<<
b_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_gs_ms_ns: "
<<
d_gs_ms_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_gs_ms_ns: "
<<
e_gs_ms_ns_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
DDataType
>
{
-
5
,
5
});
break
;
default:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
DDataType
>
{
-
0.5
,
0.5
});
break
;
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
d_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_gs_ms_ns_device_result
.
mDesc
.
GetElementSpaceSize
());
a_device_buf
.
ToDevice
(
a_gs_ms_ks
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_gs_ns_ks
.
mData
.
data
());
d_device_buf
.
ToDevice
(
d_gs_ms_ns
.
mData
.
data
());
// set zero
e_device_buf
.
SetZero
();
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{};
// device operation
auto
op
=
DeviceOpInstance
{};
auto
invoker
=
op
.
MakeInvoker
();
auto
argument
=
op
.
MakeArgument
(
a_device_buf
.
GetDeviceBuffer
(),
b_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d_device_buf
.
GetDeviceBuffer
()},
e_device_buf
.
GetDeviceBuffer
(),
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_gs_ms_ns_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_gs_ms_ns_strides
},
e_gs_ms_ns_lengths
,
e_gs_ms_ns_strides
,
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
op
.
IsSupportedArgument
(
argument
))
{
std
::
cout
<<
op
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
ck
::
index_t
G
=
ck
::
accumulate_n
<
ck
::
index_t
>
(
e_gs_ms_ns_lengths
.
begin
(),
NumDimG
,
1
,
std
::
multiplies
<>
{});
ck
::
index_t
M
=
ck
::
accumulate_n
<
ck
::
index_t
>
(
e_gs_ms_ns_lengths
.
begin
()
+
NumDimG
,
NumDimM
,
1
,
std
::
multiplies
<>
{});
ck
::
index_t
N
=
ck
::
accumulate_n
<
ck
::
index_t
>
(
e_gs_ms_ns_lengths
.
begin
()
+
NumDimG
+
NumDimM
,
NumDimN
,
1
,
std
::
multiplies
<>
{});
ck
::
index_t
K
=
ck
::
accumulate_n
<
ck
::
index_t
>
(
a_gs_ms_ks_lengths
.
begin
()
+
NumDimG
+
NumDimM
,
NumDimK
,
1
,
std
::
multiplies
<>
{});
std
::
cout
<<
"GMNK="
<<
G
<<
", "
<<
M
<<
", "
<<
N
<<
", "
<<
K
<<
std
::
endl
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
G
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
G
*
M
*
K
+
sizeof
(
BDataType
)
*
G
*
K
*
N
+
sizeof
(
DDataType
)
*
G
*
M
*
N
+
sizeof
(
EDataType
)
*
G
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op
.
GetTypeString
()
<<
std
::
endl
;
e_device_buf
.
FromDevice
(
e_gs_ms_ns_device_result
.
mData
.
data
());
if
(
do_verification
)
{
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result
(
e_gs_ms_ns_lengths
,
e_gs_ms_ns_strides
);
using
ReferenceOpInstance
=
ReferenceContraction_G2_M2_N2_K1
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
CShuffleDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
PassThrough
>
;
auto
ref_gemm
=
ReferenceOpInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_gs_ms_ks
,
b_gs_ns_ks
,
c_ms_ns_host_result
,
a_element_op
,
b_element_op
,
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
for
(
size_t
g0
=
0
;
g0
<
e_gs_ms_ns_host_result
.
mDesc
.
GetLengths
()[
0
];
++
g0
)
{
for
(
size_t
g1
=
0
;
g1
<
e_gs_ms_ns_host_result
.
mDesc
.
GetLengths
()[
1
];
++
g1
)
{
for
(
size_t
m0
=
0
;
m0
<
e_gs_ms_ns_host_result
.
mDesc
.
GetLengths
()[
2
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_gs_ms_ns_host_result
.
mDesc
.
GetLengths
()[
3
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_gs_ms_ns_host_result
.
mDesc
.
GetLengths
()[
4
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_gs_ms_ns_host_result
.
mDesc
.
GetLengths
()[
5
];
++
n1
)
{
cde_element_op
(
e_gs_ms_ns_host_result
(
g0
,
g1
,
m0
,
m1
,
n0
,
n1
),
c_ms_ns_host_result
(
g0
,
g1
,
m0
,
m1
,
n0
,
n1
),
d_gs_ms_ns
(
g0
,
g1
,
m0
,
m1
,
n0
,
n1
));
}
}
}
}
}
}
return
ck
::
utils
::
check_err
(
e_gs_ms_ns_device_result
,
e_gs_ms_ns_host_result
)
?
0
:
1
;
}
return
0
;
}
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
View file @
a72a5762
...
...
@@ -16,6 +16,9 @@ if(USE_BITINT_EXTENSION_INT4)
add_dependencies
(
example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4
)
endif
()
# USE_BITINT_EXTENSION_INT4
if
(
GPU_TARGETS MATCHES
"gfx1100"
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp
)
endif
()
add_example_executable
(
example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp
)
...
...
example/30_grouped_conv_fwd_multiple_d/common.hpp
View file @
a72a5762
...
...
@@ -137,7 +137,7 @@ inline bool parse_cmd_args(int argc,
const
ck
::
index_t
num_dim_spatial
=
std
::
stoi
(
argv
[
4
]);
conv_param
=
ck
::
utils
::
conv
::
parse_conv_param
(
num_dim_spatial
,
threshold_to_catch_partial_args
,
argv
);
num_dim_spatial
,
threshold_to_catch_partial_args
+
1
,
argv
);
}
else
{
...
...
example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp
0 → 100644
View file @
a72a5762
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <array>
#include <iostream>
#include <string>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
using
BF16
=
ck
::
bhalf_t
;
using
FP16
=
ck
::
half_t
;
using
FP32
=
float
;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
I4
=
ck
::
int4_t
;
#endif
using
I8
=
std
::
int8_t
;
using
I32
=
std
::
int32_t
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
ConvSpec
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
template
<
typename
InputLay
,
typename
WeightLay
,
typename
OutputLay
>
struct
CommonLayoutSetting
{
using
InputLayout
=
InputLay
;
using
WeightLayout
=
WeightLay
;
using
OutputLayout
=
OutputLay
;
};
template
<
ck
::
index_t
NDimSpatial
>
struct
CommonLayoutSettingSelector
;
namespace
ctl
=
ck
::
tensor_layout
::
convolution
;
template
<
>
struct
CommonLayoutSettingSelector
<
1
>
final
:
CommonLayoutSetting
<
ctl
::
G_NW_C
,
ctl
::
G_K_X_C
,
ctl
::
G_NW_K
>
{
};
template
<
>
struct
CommonLayoutSettingSelector
<
2
>
final
:
CommonLayoutSetting
<
ctl
::
G_NHW_C
,
ctl
::
G_K_YX_C
,
ctl
::
G_NHW_K
>
{
};
template
<
>
struct
CommonLayoutSettingSelector
<
3
>
final
:
CommonLayoutSetting
<
ctl
::
G_NDHW_C
,
ctl
::
G_K_ZYX_C
,
ctl
::
G_NDHW_K
>
{
};
template
<
ck
::
index_t
NDimSpatial
>
using
InputLayout
=
typename
CommonLayoutSettingSelector
<
NDimSpatial
>::
InputLayout
;
template
<
ck
::
index_t
NDimSpatial
>
using
WeightLayout
=
typename
CommonLayoutSettingSelector
<
NDimSpatial
>::
WeightLayout
;
template
<
ck
::
index_t
NDimSpatial
>
using
OutputLayout
=
typename
CommonLayoutSettingSelector
<
NDimSpatial
>::
OutputLayout
;
struct
ExecutionConfig
final
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
true
;
};
#define DefaultConvParam \
ck::utils::conv::ConvParam \
{ \
2, 32, 2, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
}
inline
void
print_help_msg
()
{
std
::
cerr
<<
"arg1: verification (0=no, 1=yes)
\n
"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
<<
"arg3: time kernel (0=no, 1=yes)
\n
"
<<
ck
::
utils
::
conv
::
get_conv_param_parser_helper_msg
()
<<
std
::
endl
;
}
inline
bool
parse_cmd_args
(
int
argc
,
char
*
argv
[],
ExecutionConfig
&
config
,
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
{
constexpr
int
num_execution_config_args
=
3
;
// arguments for do_verification, init_method, time_kernel
constexpr
int
num_conv_param_leading_args
=
5
;
// arguments for num_dim_spatial_, G_, N_, K_, C_
constexpr
int
threshold_to_catch_partial_args
=
1
+
num_execution_config_args
;
constexpr
int
threshold_to_catch_all_args
=
threshold_to_catch_partial_args
+
num_conv_param_leading_args
;
if
(
argc
==
1
)
{
// use default
}
// catch only ExecutionConfig arguments
else
if
(
argc
==
threshold_to_catch_partial_args
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
// catch both ExecutionConfig & ConvParam arguments
else
if
(
threshold_to_catch_all_args
<
argc
&&
((
argc
-
threshold_to_catch_all_args
)
%
3
==
0
))
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
const
ck
::
index_t
num_dim_spatial
=
std
::
stoi
(
argv
[
4
]);
conv_param
=
ck
::
utils
::
conv
::
parse_conv_param
(
num_dim_spatial
,
threshold_to_catch_partial_args
+
1
,
argv
);
}
else
{
print_help_msg
();
return
false
;
}
return
true
;
}
inline
HostTensorDescriptor
make_input_descriptor
(
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
{
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
HostTensorDescriptor
(
{
conv_param
.
G_
,
conv_param
.
N_
,
conv_param
.
C_
,
conv_param
.
input_spatial_lengths_
[
0
]},
{
conv_param
.
C_
,
// g
conv_param
.
input_spatial_lengths_
[
0
]
*
conv_param
.
G_
*
conv_param
.
C_
,
// n
1
,
// c
conv_param
.
G_
*
conv_param
.
C_
// wi
});
case
2
:
return
HostTensorDescriptor
(
{
conv_param
.
G_
,
conv_param
.
N_
,
conv_param
.
C_
,
conv_param
.
input_spatial_lengths_
[
0
],
conv_param
.
input_spatial_lengths_
[
1
]},
{
conv_param
.
C_
,
// g
conv_param
.
input_spatial_lengths_
[
0
]
*
conv_param
.
input_spatial_lengths_
[
1
]
*
conv_param
.
G_
*
conv_param
.
C_
,
// n
1
,
// c
conv_param
.
input_spatial_lengths_
[
1
]
*
conv_param
.
G_
*
conv_param
.
C_
,
// hi
conv_param
.
G_
*
conv_param
.
C_
// wi
});
case
3
:
return
HostTensorDescriptor
(
{
conv_param
.
G_
,
conv_param
.
N_
,
conv_param
.
C_
,
conv_param
.
input_spatial_lengths_
[
0
],
conv_param
.
input_spatial_lengths_
[
1
],
conv_param
.
input_spatial_lengths_
[
2
]},
{
conv_param
.
C_
,
// g
conv_param
.
input_spatial_lengths_
[
0
]
*
conv_param
.
input_spatial_lengths_
[
1
]
*
conv_param
.
input_spatial_lengths_
[
2
]
*
conv_param
.
G_
*
conv_param
.
C_
,
// n
1
,
// c
conv_param
.
input_spatial_lengths_
[
1
]
*
conv_param
.
input_spatial_lengths_
[
2
]
*
conv_param
.
G_
*
conv_param
.
C_
,
// di
conv_param
.
input_spatial_lengths_
[
2
]
*
conv_param
.
G_
*
conv_param
.
C_
,
// hi
conv_param
.
G_
*
conv_param
.
C_
// wi
});
}
throw
std
::
runtime_error
(
"unsuppored # dim spatial"
);
}
inline
HostTensorDescriptor
make_weight_descriptor
(
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
{
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
HostTensorDescriptor
(
{
conv_param
.
G_
,
conv_param
.
K_
,
conv_param
.
C_
,
conv_param
.
filter_spatial_lengths_
[
0
]},
{
conv_param
.
K_
*
conv_param
.
filter_spatial_lengths_
[
0
]
*
conv_param
.
C_
,
// g
conv_param
.
filter_spatial_lengths_
[
0
]
*
conv_param
.
C_
,
// k
1
,
// c
conv_param
.
C_
// x
});
case
2
:
return
HostTensorDescriptor
(
{
conv_param
.
G_
,
conv_param
.
K_
,
conv_param
.
C_
,
conv_param
.
filter_spatial_lengths_
[
0
],
conv_param
.
filter_spatial_lengths_
[
1
]},
{
conv_param
.
K_
*
conv_param
.
filter_spatial_lengths_
[
0
]
*
conv_param
.
filter_spatial_lengths_
[
1
]
*
conv_param
.
C_
,
// g
conv_param
.
filter_spatial_lengths_
[
0
]
*
conv_param
.
filter_spatial_lengths_
[
1
]
*
conv_param
.
C_
,
// k
1
,
// c
conv_param
.
filter_spatial_lengths_
[
1
]
*
conv_param
.
C_
,
// y
conv_param
.
C_
// x
});
case
3
:
return
HostTensorDescriptor
(
{
conv_param
.
G_
,
conv_param
.
K_
,
conv_param
.
C_
,
conv_param
.
filter_spatial_lengths_
[
0
],
conv_param
.
filter_spatial_lengths_
[
1
],
conv_param
.
filter_spatial_lengths_
[
2
]},
{
conv_param
.
K_
*
conv_param
.
filter_spatial_lengths_
[
0
]
*
conv_param
.
filter_spatial_lengths_
[
1
]
*
conv_param
.
filter_spatial_lengths_
[
2
]
*
conv_param
.
C_
,
// g
conv_param
.
filter_spatial_lengths_
[
0
]
*
conv_param
.
filter_spatial_lengths_
[
1
]
*
conv_param
.
filter_spatial_lengths_
[
2
]
*
conv_param
.
C_
,
// k
1
,
// c
conv_param
.
filter_spatial_lengths_
[
1
]
*
conv_param
.
filter_spatial_lengths_
[
2
]
*
conv_param
.
C_
,
// z
conv_param
.
filter_spatial_lengths_
[
2
]
*
conv_param
.
C_
,
// y
conv_param
.
C_
// x
});
}
throw
std
::
runtime_error
(
"unsuppored # dim spatial"
);
}
inline
HostTensorDescriptor
make_bias_descriptor
(
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
{
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
HostTensorDescriptor
(
{
conv_param
.
G_
,
conv_param
.
N_
,
conv_param
.
K_
,
conv_param
.
output_spatial_lengths_
[
0
]},
{
conv_param
.
K_
,
// g
0
,
// k
1
,
// c
0
// x
});
case
2
:
return
HostTensorDescriptor
({
conv_param
.
G_
,
conv_param
.
N_
,
conv_param
.
K_
,
conv_param
.
output_spatial_lengths_
[
0
],
conv_param
.
output_spatial_lengths_
[
1
]},
{
conv_param
.
K_
,
// g
0
,
// n
1
,
// k
0
,
// ho
0
// wo
});
case
3
:
return
HostTensorDescriptor
({
conv_param
.
G_
,
conv_param
.
N_
,
conv_param
.
K_
,
conv_param
.
output_spatial_lengths_
[
0
],
conv_param
.
output_spatial_lengths_
[
1
],
conv_param
.
output_spatial_lengths_
[
2
]},
{
conv_param
.
K_
,
// g
0
,
// n
1
,
// k
0
,
// z
0
,
// y
0
// x
});
}
throw
std
::
runtime_error
(
"unsuppored # dim spatial"
);
}
inline
HostTensorDescriptor
make_output_descriptor
(
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
{
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
HostTensorDescriptor
(
{
conv_param
.
G_
,
conv_param
.
N_
,
conv_param
.
K_
,
conv_param
.
output_spatial_lengths_
[
0
]},
{
conv_param
.
K_
,
// g
conv_param
.
output_spatial_lengths_
[
0
]
*
conv_param
.
G_
*
conv_param
.
K_
,
// n
1
,
// k
conv_param
.
G_
*
conv_param
.
K_
// wo
});
case
2
:
return
HostTensorDescriptor
(
{
conv_param
.
G_
,
conv_param
.
N_
,
conv_param
.
K_
,
conv_param
.
output_spatial_lengths_
[
0
],
conv_param
.
output_spatial_lengths_
[
1
]},
{
conv_param
.
K_
,
// g
conv_param
.
output_spatial_lengths_
[
0
]
*
conv_param
.
output_spatial_lengths_
[
1
]
*
conv_param
.
G_
*
conv_param
.
K_
,
// n
1
,
// k
conv_param
.
output_spatial_lengths_
[
1
]
*
conv_param
.
G_
*
conv_param
.
K_
,
// ho
conv_param
.
G_
*
conv_param
.
K_
// wo
});
case
3
:
return
HostTensorDescriptor
(
{
conv_param
.
G_
,
conv_param
.
N_
,
conv_param
.
K_
,
conv_param
.
output_spatial_lengths_
[
0
],
conv_param
.
output_spatial_lengths_
[
1
],
conv_param
.
output_spatial_lengths_
[
2
]},
{
conv_param
.
K_
,
// g
conv_param
.
output_spatial_lengths_
[
0
]
*
conv_param
.
output_spatial_lengths_
[
1
]
*
conv_param
.
output_spatial_lengths_
[
2
]
*
conv_param
.
G_
*
conv_param
.
K_
,
// n
1
,
// k
conv_param
.
output_spatial_lengths_
[
1
]
*
conv_param
.
output_spatial_lengths_
[
2
]
*
conv_param
.
G_
*
conv_param
.
K_
,
// do
conv_param
.
output_spatial_lengths_
[
2
]
*
conv_param
.
G_
*
conv_param
.
K_
,
// ho
conv_param
.
G_
*
conv_param
.
K_
// wo
});
}
throw
std
::
runtime_error
(
"unsuppored # dim spatial"
);
}
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp
0 → 100644
View file @
a72a5762
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common_wmma.hpp"
// kernel data types
using
InKernelDataType
=
FP16
;
using
WeiKernelDataType
=
FP16
;
using
AccDataType
=
FP32
;
using
CShuffleDataType
=
FP16
;
using
BiasKernelDataType
=
FP16
;
using
ResidualKernelDataType
=
FP16
;
using
OutKernelDataType
=
FP16
;
// tensor data types
using
InUserDataType
=
InKernelDataType
;
using
WeiUserDataType
=
WeiKernelDataType
;
using
OutUserDataType
=
OutKernelDataType
;
using
InElementOp
=
PassThrough
;
using
WeiElementOp
=
PassThrough
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
AddReluAdd
;
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_fwd_bias_relu_add_example
(
argc
,
argv
);
}
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
0 → 100644
View file @
a72a5762
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
template
<
typename
BiasLay
,
typename
ResidualLay
>
struct
LayoutSetting
{
using
BiasLayout
=
BiasLay
;
using
ResidualLayout
=
ResidualLay
;
};
template
<
ck
::
index_t
NDimSpatial
>
struct
LayoutSettingSelector
;
template
<>
struct
LayoutSettingSelector
<
1
>
final
:
LayoutSetting
<
ctl
::
G_K
,
ctl
::
G_NW_K
>
{
};
template
<>
struct
LayoutSettingSelector
<
2
>
final
:
LayoutSetting
<
ctl
::
G_K
,
ctl
::
G_NHW_K
>
{
};
template
<>
struct
LayoutSettingSelector
<
3
>
final
:
LayoutSetting
<
ctl
::
G_K
,
ctl
::
G_NDHW_K
>
{
};
template
<
ck
::
index_t
NDimSpatial
>
using
BiasLayout
=
typename
LayoutSettingSelector
<
NDimSpatial
>::
BiasLayout
;
template
<
ck
::
index_t
NDimSpatial
>
using
ResidualLayout
=
typename
LayoutSettingSelector
<
NDimSpatial
>::
ResidualLayout
;
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConvFwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
<
NDimSpatial
,
InputLayout
<
NDimSpatial
>
,
WeightLayout
<
NDimSpatial
>
,
ck
::
Tuple
<
BiasLayout
<
NDimSpatial
>
,
ResidualLayout
<
NDimSpatial
>>
,
OutputLayout
<
NDimSpatial
>
,
InKernelDataType
,
WeiKernelDataType
,
ck
::
Tuple
<
BiasKernelDataType
,
ResidualKernelDataType
>
,
OutKernelDataType
,
AccDataType
,
CShuffleDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
ConvSpec
,
// ConvForwardSpecialization
GemmSpec
,
// GemmSpecialization
256
,
// BlockSize
128
,
// MPerBlock
128
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
16
,
// MPerWMMA
16
,
// NPerWMMA
4
,
// MRepeat
2
,
// NRepeat
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_AK0_M_AK1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_AK1
true
,
// ABlockLdsExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_BK0_N_BK1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_BK1
true
,
// BBlockLdsExtraN
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
template
<
ck
::
index_t
NDimSpatial
>
using
HostConvFwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvFwd
<
NDimSpatial
,
InUserDataType
,
WeiUserDataType
,
CShuffleDataType
,
InElementOp
,
WeiElementOp
,
PassThrough
>
;
template
<
ck
::
index_t
NDimSpatial
>
bool
run_grouped_conv_fwd_bias_relu_add
(
const
ExecutionConfig
&
config
,
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
{
static_assert
(
1
<=
NDimSpatial
&&
NDimSpatial
<=
3
,
"Unsupported NDimSpatial"
);
const
auto
in_g_n_c_wis_desc
=
make_input_descriptor
(
conv_param
);
const
auto
wei_g_k_c_xs_desc
=
make_weight_descriptor
(
conv_param
);
const
auto
bias_g_n_k_wos_desc
=
make_bias_descriptor
(
conv_param
);
const
auto
out_g_n_k_wos_desc
=
make_output_descriptor
(
conv_param
);
Tensor
<
InUserDataType
>
in
(
in_g_n_c_wis_desc
);
Tensor
<
WeiUserDataType
>
wei
(
wei_g_k_c_xs_desc
);
Tensor
<
OutUserDataType
>
bias
(
bias_g_n_k_wos_desc
);
Tensor
<
OutUserDataType
>
residual
(
bias_g_n_k_wos_desc
);
Tensor
<
OutUserDataType
>
out_host
(
out_g_n_k_wos_desc
);
Tensor
<
OutKernelDataType
>
out_device
(
out_g_n_k_wos_desc
);
std
::
cout
<<
"in: "
<<
in
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei: "
<<
wei
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"bias: "
<<
bias
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"residual: "
<<
residual
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out: "
<<
out_host
.
mDesc
<<
std
::
endl
;
switch
(
config
.
init_method
)
{
case
0
:
break
;
case
1
:
in
.
GenerateTensorValue
(
GeneratorTensor_2
<
InUserDataType
>
{
-
5
,
5
});
wei
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiUserDataType
>
{
-
5
,
5
});
bias
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutUserDataType
>
{
-
5
,
5
});
break
;
default
:
in
.
GenerateTensorValue
(
GeneratorTensor_3
<
InUserDataType
>
{
0.0
,
1.0
});
wei
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiUserDataType
>
{
-
0.5
,
0.5
});
bias
.
GenerateTensorValue
(
GeneratorTensor_3
<
OutUserDataType
>
{
-
0.5
,
0.5
});
}
DeviceMem
in_device_buf
(
sizeof
(
InKernelDataType
)
*
in
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiKernelDataType
)
*
wei
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
bias_device_buf
(
sizeof
(
OutKernelDataType
)
*
bias
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
residual_device_buf
(
sizeof
(
OutKernelDataType
)
*
residual
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
out_device_buf
(
sizeof
(
OutKernelDataType
)
*
out_device
.
mDesc
.
GetElementSpaceSize
());
#ifdef BUILD_INT4_EXAMPLE
const
Tensor
<
InKernelDataType
>
in_converted
(
in
);
const
Tensor
<
WeiKernelDataType
>
wei_converted
(
wei
);
const
Tensor
<
OutKernelDataType
>
bias_converted
(
bias
);
const
Tensor
<
OutKernelDataType
>
residual_converted
(
residual
);
in_device_buf
.
ToDevice
(
in_converted
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_converted
.
mData
.
data
());
bias_device_buf
.
ToDevice
(
bias_converted
.
mData
.
data
());
residual_device_buf
.
ToDevice
(
residual_converted
.
mData
.
data
());
#else
in_device_buf
.
ToDevice
(
in
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei
.
mData
.
data
());
bias_device_buf
.
ToDevice
(
bias
.
mData
.
data
());
residual_device_buf
.
ToDevice
(
residual
.
mData
.
data
());
#endif
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
d0_g_n_k_wos_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
d0_g_n_k_wos_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
d1_g_n_k_wos_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
d1_g_n_k_wos_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
{};
auto
copy
=
[](
const
auto
&
x
,
auto
&
y
)
{
ck
::
ranges
::
copy
(
x
,
y
.
begin
());
};
copy
(
in_g_n_c_wis_desc
.
GetLengths
(),
a_g_n_c_wis_lengths
);
copy
(
in_g_n_c_wis_desc
.
GetStrides
(),
a_g_n_c_wis_strides
);
copy
(
wei_g_k_c_xs_desc
.
GetLengths
(),
b_g_k_c_xs_lengths
);
copy
(
wei_g_k_c_xs_desc
.
GetStrides
(),
b_g_k_c_xs_strides
);
copy
(
bias_g_n_k_wos_desc
.
GetLengths
(),
d0_g_n_k_wos_lengths
);
copy
(
bias_g_n_k_wos_desc
.
GetStrides
(),
d0_g_n_k_wos_strides
);
copy
(
bias_g_n_k_wos_desc
.
GetLengths
(),
d1_g_n_k_wos_lengths
);
copy
(
bias_g_n_k_wos_desc
.
GetStrides
(),
d1_g_n_k_wos_strides
);
copy
(
out_g_n_k_wos_desc
.
GetLengths
(),
e_g_n_k_wos_lengths
);
copy
(
out_g_n_k_wos_desc
.
GetStrides
(),
e_g_n_k_wos_strides
);
copy
(
conv_param
.
conv_filter_strides_
,
conv_filter_strides
);
copy
(
conv_param
.
conv_filter_dilations_
,
conv_filter_dilations
);
copy
(
conv_param
.
input_left_pads_
,
input_left_pads
);
copy
(
conv_param
.
input_right_pads_
,
input_right_pads
);
// do Conv
auto
conv
=
DeviceConvFwdInstance
<
NDimSpatial
>
{};
auto
invoker
=
conv
.
MakeInvoker
();
auto
argument
=
conv
.
MakeArgument
(
in_device_buf
.
GetDeviceBuffer
(),
wei_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
2
>
{
bias_device_buf
.
GetDeviceBuffer
(),
residual_device_buf
.
GetDeviceBuffer
()},
out_device_buf
.
GetDeviceBuffer
(),
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
std
::
array
<
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
,
2
>
{
{
d0_g_n_k_wos_lengths
,
d1_g_n_k_wos_lengths
}},
std
::
array
<
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
,
2
>
{
{
d0_g_n_k_wos_strides
,
d1_g_n_k_wos_strides
}},
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{});
if
(
!
conv
.
IsSupportedArgument
(
argument
))
{
throw
std
::
runtime_error
(
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
);
}
float
avg_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
std
::
size_t
flop
=
conv_param
.
GetFlops
();
std
::
size_t
num_btype
=
conv_param
.
GetByte
<
InUserDataType
,
WeiUserDataType
,
OutUserDataType
>
();
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
avg_time
;
std
::
cout
<<
"Perf: "
<<
avg_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
conv
.
GetTypeString
()
<<
std
::
endl
;
if
(
config
.
do_verification
)
{
Tensor
<
CShuffleDataType
>
c_host
(
out_g_n_k_wos_desc
);
auto
ref_conv
=
HostConvFwdInstance
<
NDimSpatial
>
{};
auto
ref_invoker
=
ref_conv
.
MakeInvoker
();
auto
ref_argument
=
ref_conv
.
MakeArgument
(
in
,
wei
,
c_host
,
conv_param
.
conv_filter_strides_
,
conv_param
.
conv_filter_dilations_
,
conv_param
.
input_left_pads_
,
conv_param
.
input_right_pads_
,
InElementOp
{},
WeiElementOp
{},
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
// TODO: implement elementwise operation for host
out_host
.
ForEach
([
&
](
auto
&
,
auto
idx
)
{
OutElementOp
{}(
out_host
(
idx
),
c_host
(
idx
),
bias
(
idx
),
residual
(
idx
));
});
out_device_buf
.
FromDevice
(
out_device
.
mData
.
data
());
#ifdef BUILD_INT4_EXAMPLE
const
Tensor
<
OutUserDataType
>
out_device_converted
(
out_device
);
return
ck
::
utils
::
check_err
(
out_device_converted
,
out_host
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
#else
return
ck
::
utils
::
check_err
(
out_device
,
out_host
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
#endif
}
return
true
;
}
bool
run_grouped_conv_fwd_bias_relu_add_example
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
false
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
run_grouped_conv_fwd_bias_relu_add
<
1
>
(
config
,
conv_param
);
case
2
:
return
run_grouped_conv_fwd_bias_relu_add
<
2
>
(
config
,
conv_param
);
case
3
:
return
run_grouped_conv_fwd_bias_relu_add
<
3
>
(
config
,
conv_param
);
}
return
false
;
}
example/42_groupnorm/groupnorm_sigmoid_fp16.cpp
View file @
a72a5762
...
...
@@ -23,11 +23,11 @@
constexpr
int
Rank
=
5
;
constexpr
int
NumReduceDim
=
3
;
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
Acc
DataType
=
float
;
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
Compute
DataType
=
float
;
struct
YElementOp
{
...
...
@@ -50,7 +50,7 @@ using DeviceInstance =
ck
::
tensor_operation
::
device
::
DeviceNormalizationImpl
<
XDataType
,
GammaDataType
,
BetaDataType
,
Acc
DataType
,
Compute
DataType
,
YDataType
,
YElementOp
,
Rank
,
...
...
@@ -157,7 +157,7 @@ int main(int argc, char* argv[])
GammaDataType
,
BetaDataType
,
YDataType
,
Acc
DataType
,
Compute
DataType
,
YElementOp
>
;
ReferenceInstance
ref
;
...
...
example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc
View file @
a72a5762
...
...
@@ -53,7 +53,6 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
...
...
@@ -84,8 +83,8 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi
if
(
!
device_op
.
IsSupportedArgument
(
argument
))
{
std
::
cout
<<
"wrong! this device_op instance does not support this problem"
<<
std
::
endl
;
return
true
;
std
::
cout
<<
"wrong! this device_op instance does not support this problem"
<<
std
::
endl
;
return
true
;
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
...
...
include/ck/ck.hpp
View file @
a72a5762
...
...
@@ -168,13 +168,6 @@
// tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0
// workaround: a BF16 attention kernel for gfx908 is likely affected by a compiler issue
#ifdef __gfx908__
#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 1
#else // __gfx90a__, ...
#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 0
#endif // __gfx908__
// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0
...
...
include/ck/host_utility/kernel_launch.hpp
View file @
a72a5762
...
...
@@ -20,6 +20,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
#if CK_TIME_KERNEL
if
(
stream_config
.
time_kernel_
)
{
#if DEBUG_LOG
printf
(
"%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}
\n
"
,
__func__
,
grid_dim
.
x
,
...
...
@@ -29,15 +30,15 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
block_dim
.
y
,
block_dim
.
z
);
const
int
nrepeat
=
10
;
printf
(
"Warm up 1 time
\n
"
);
#endif
// warm up
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
const
int
nrepeat
=
10
;
#if DEBUG_LOG
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
#endif
hipEvent_t
start
,
stop
;
hip_check_error
(
hipEventCreate
(
&
start
));
...
...
include/ck/tensor_operation/gpu/device/device_normalization.hpp
View file @
a72a5762
...
...
@@ -14,9 +14,9 @@ namespace device {
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
Acc
DataType
,
typename
Compute
DataType
,
typename
YDataType
,
typename
Acc
ElementwiseOperation
,
typename
Y
ElementwiseOperation
,
index_t
Rank
,
index_t
NumReduceDim
>
struct
DeviceNormalization
:
public
BaseOperator
...
...
@@ -35,7 +35,7 @@ struct DeviceNormalization : public BaseOperator
void
*
p_y
,
void
*
p_savedMean
,
void
*
p_savedInvVar
,
Acc
ElementwiseOperation
acc
_elementwise_op
)
=
0
;
Y
ElementwiseOperation
y
_elementwise_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
...
...
@@ -43,17 +43,17 @@ struct DeviceNormalization : public BaseOperator
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
Acc
DataType
,
typename
Compute
DataType
,
typename
YDataType
,
typename
Acc
ElementwiseOperation
,
typename
Y
ElementwiseOperation
,
index_t
Rank
,
index_t
NumReduceDim
>
using
DeviceNormalizationPtr
=
std
::
unique_ptr
<
DeviceNormalization
<
XDataType
,
GammaDataType
,
BetaDataType
,
Acc
DataType
,
Compute
DataType
,
YDataType
,
Acc
ElementwiseOperation
,
Y
ElementwiseOperation
,
Rank
,
NumReduceDim
>>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
0 → 100644
View file @
a72a5762
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner
// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and
// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted
// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into
// TensorSpecialization::Default with NumDimG/M/N/K = 1
//
// Detail- Packed tensor satisfies
// stride_0 = 1
// stride_i = stride_{i - 1} * extent_{i - 1}
// So tensor
// [G0, G1, G2, M, N]
// transposed into tensor
// [G0, G2, G1, M, N]
// with strides
// [G2 * G1 * M * N, G1 * M * N, M * N, N, 1]
// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores some
// strides from input tensor extents so finer dimension information is lost. Merging dimensions is
// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
//
// Might need to expose dimension order to the interface to fully support
// TensorSpecialization::Packed in a traditional sense of "packed" tensor
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
GemmSpecialization
GemmSpec
,
TensorSpecialization
ASpec
,
TensorSpecialization
BSpec
,
TensorSpecialization
DESpec
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerWMMA
,
ck
::
index_t
NPerWMMA
,
ck
::
index_t
MRepeat
,
ck
::
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
ck
::
index_t
NumPrefetch
=
1
,
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
>
struct
DeviceBatchedContractionMultipleD_Wmma_CShuffle
:
public
DeviceBatchedContractionMultipleD
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
using
DeviceOp
=
DeviceBatchedContractionMultipleD_Wmma_CShuffle
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
// K1 = Max Vector Access Pixels
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K0PerBlock
*
K1
};
// Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
{
assert
(
a_gs_ms_ks_lengths_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimK
&&
a_gs_ms_ks_strides_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimK
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
start
,
auto
end
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
start
+
i
];
},
Number
<
end
-
start
>
{});
};
const
auto
a_ms_ks_lengths
=
to_tuple
(
a_gs_ms_ks_lengths_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimK
>
{});
const
auto
a_ms_ks_strides
=
to_tuple
(
a_gs_ms_ks_strides_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimK
>
{});
// dimension Ids for M0, M1, ...
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimM
,
1
>::
type
{};
// dimension Ids for K0, K1, ...
constexpr
auto
kDimIds
=
typename
arithmetic_sequence_gen
<
NumDimM
,
NumDimM
+
NumDimK
,
1
>::
type
{};
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
a_ms_ks_lengths
,
mDimIds
);
// lengths for K0, K1, ...
const
auto
kLengths
=
get_container_subset
(
a_ms_ks_lengths
,
kDimIds
);
if
constexpr
(
ASpec
==
TensorSpecialization
::
Packed
)
{
auto
M
=
container_reduce
(
mLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
K
=
container_reduce
(
kLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
const
auto
a_grid_desc_mraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
a_ms_ks_strides
[
Number
<
NumDimM
-
1
>
{}],
a_ms_ks_strides
[
Number
<
NumDimM
+
NumDimK
-
1
>
{}]));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
else
{
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const
auto
a_grid_desc_ms_ks
=
make_naive_tensor_descriptor
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const
auto
a_grid_desc_mraw_kraw
=
transform_tensor_descriptor
(
a_grid_desc_ms_ks
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
mDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
}
// Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
static
auto
MakeBGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides_vec
)
{
assert
(
b_gs_ns_ks_lengths_vec
.
size
()
==
NumDimG
+
NumDimN
+
NumDimK
&&
b_gs_ns_ks_strides_vec
.
size
()
==
NumDimG
+
NumDimN
+
NumDimK
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
start
,
auto
end
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
start
+
i
];
},
Number
<
end
-
start
>
{});
};
const
auto
b_ns_ks_lengths
=
to_tuple
(
b_gs_ns_ks_lengths_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimN
+
NumDimK
>
{});
const
auto
b_ns_ks_strides
=
to_tuple
(
b_gs_ns_ks_strides_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimN
+
NumDimK
>
{});
// dimension Ids for N0, N1, ...
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimN
,
1
>::
type
{};
// dimension Ids for K0, K1, ...
constexpr
auto
kDimIds
=
typename
arithmetic_sequence_gen
<
NumDimN
,
NumDimN
+
NumDimK
,
1
>::
type
{};
// lengths for K0, K1, ...
const
auto
kLengths
=
get_container_subset
(
b_ns_ks_lengths
,
kDimIds
);
// lengths for N0, N1, ...
const
auto
nLengths
=
get_container_subset
(
b_ns_ks_lengths
,
nDimIds
);
if
constexpr
(
BSpec
==
TensorSpecialization
::
Packed
)
{
auto
N
=
container_reduce
(
nLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
K
=
container_reduce
(
kLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
const
auto
b_grid_desc_nraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
K
),
make_tuple
(
b_ns_ks_strides
[
Number
<
NumDimN
-
1
>
{}],
b_ns_ks_strides
[
Number
<
NumDimN
+
NumDimK
-
1
>
{}]));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
else
{
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const
auto
b_grid_desc_ns_ks
=
make_naive_tensor_descriptor
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const
auto
b_grid_desc_nraw_kraw
=
transform_tensor_descriptor
(
b_grid_desc_ns_ks
,
make_tuple
(
make_merge_transform
(
nLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
nDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
}
// assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_strides_vec
)
{
assert
(
e_gs_ms_ns_lengths_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimN
&&
e_gs_ms_ns_strides_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimN
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
start
,
auto
end
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
start
+
i
];
},
Number
<
end
-
start
>
{});
};
const
auto
e_ms_ns_lengths
=
to_tuple
(
e_gs_ms_ns_lengths_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimN
>
{});
const
auto
e_ms_ns_strides
=
to_tuple
(
e_gs_ms_ns_strides_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimN
>
{});
// dimension Ids for M0, M1, ...
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimM
,
1
>::
type
{};
// dimension Ids for N0, N1, ...
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
NumDimM
,
NumDimM
+
NumDimN
,
1
>::
type
{};
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
e_ms_ns_lengths
,
mDimIds
);
// lengths for K0, K1, ...
const
auto
nLengths
=
get_container_subset
(
e_ms_ns_lengths
,
nDimIds
);
if
constexpr
(
DESpec
==
TensorSpecialization
::
Packed
)
{
auto
M
=
container_reduce
(
mLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
N
=
container_reduce
(
nLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
const
auto
e_grid_desc_mraw_nraw
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
e_ms_ns_strides
[
Number
<
NumDimM
-
1
>
{}],
e_ms_ns_strides
[
Number
<
NumDimM
+
NumDimN
-
1
>
{}]));
return
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
}
else
{
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
const
auto
e_grid_desc_ms_ns
=
make_naive_tensor_descriptor
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
// transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const
auto
e_grid_desc_mraw_nraw
=
transform_tensor_descriptor
(
e_grid_desc_ms_ns
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
nLengths
)),
make_tuple
(
mDimIds
,
nDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
}
}
// assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static
auto
MakeEGridDescriptor_G_M_N
(
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_strides_vec
)
{
assert
(
e_gs_ms_ns_lengths_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimN
&&
e_gs_ms_ns_strides_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimN
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
start
,
auto
end
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
start
+
i
];
},
Number
<
end
-
start
>
{});
};
const
auto
e_gs_ms_ns_lengths
=
to_tuple
(
e_gs_ms_ns_lengths_vec
,
Number
<
0
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimN
>
{});
const
auto
e_gs_ms_ns_strides
=
to_tuple
(
e_gs_ms_ns_strides_vec
,
Number
<
0
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimN
>
{});
// dimension Ids for G0, G1, ...
constexpr
auto
gDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimG
,
1
>::
type
{};
// dimension Ids for M0, M1, ...
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
NumDimG
,
NumDimG
+
NumDimM
,
1
>::
type
{};
// dimension Ids for N0, N1, ...
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
NumDimG
+
NumDimM
,
NumDimG
+
NumDimM
+
NumDimN
,
1
>::
type
{};
// lengths for G0, G1, ...
const
auto
gLengths
=
get_container_subset
(
e_gs_ms_ns_lengths
,
gDimIds
);
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
e_gs_ms_ns_lengths
,
mDimIds
);
// lengths for K0, K1, ...
const
auto
nLengths
=
get_container_subset
(
e_gs_ms_ns_lengths
,
nDimIds
);
if
constexpr
(
DESpec
==
TensorSpecialization
::
Packed
)
{
auto
G
=
container_reduce
(
gLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
M
=
container_reduce
(
mLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
N
=
container_reduce
(
nLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
const
auto
e_grid_desc_g_mraw_nraw
=
make_naive_tensor_descriptor
(
make_tuple
(
G
,
M
,
N
),
make_tuple
(
e_gs_ms_ns_strides
[
Number
<
NumDimG
-
1
>
{}],
e_gs_ms_ns_strides
[
Number
<
NumDimG
+
NumDimM
-
1
>
{}],
e_gs_ms_ns_strides
[
Number
<
NumDimG
+
NumDimM
+
NumDimN
-
1
>
{}]));
// return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
return
e_grid_desc_g_mraw_nraw
;
}
else
{
// naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const
auto
e_grid_desc_gs_ms_ns
=
make_naive_tensor_descriptor
(
e_gs_ms_ns_lengths
,
e_gs_ms_ns_strides
);
// transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const
auto
e_grid_desc_g_mraw_nraw
=
transform_tensor_descriptor
(
e_grid_desc_gs_ms_ns
,
make_tuple
(
make_merge_transform
(
gLengths
),
make_merge_transform
(
mLengths
),
make_merge_transform
(
nLengths
)),
make_tuple
(
gDimIds
,
mDimIds
,
nDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
// return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
return
e_grid_desc_g_mraw_nraw
;
}
}
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_lengths_vec
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_strides_vec
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
DeviceOp
::
MakeEGridDescriptor_M_N
(
ds_gs_ms_ns_lengths_vec
[
i
],
ds_gs_ms_ns_strides_vec
[
i
]);
},
Number
<
NumDTensor
>
{});
}
static
auto
MakeDsGridDescriptor_G_M_N
(
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_lengths_vec
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_strides_vec
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
DeviceOp
::
MakeEGridDescriptor_G_M_N
(
ds_gs_ms_ns_lengths_vec
[
i
],
ds_gs_ms_ns_strides_vec
[
i
]);
},
Number
<
NumDTensor
>
{});
}
// Gridwise descriptor, mapping to whole given provblem.
using
AGridDesc_M_K
=
decltype
(
MakeAGridDescriptor_M_K
({},
{}));
using
BGridDesc_N_K
=
decltype
(
MakeBGridDescriptor_N_K
({},
{}));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
({},
{}));
using
DsGridDesc_G_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_G_M_N
({},
{}))
>
;
using
EGridDesc_G_M_N
=
decltype
(
MakeEGridDescriptor_G_M_N
({},
{}));
struct
ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch
(
index_t
batch_stride_A
,
index_t
batch_stride_B
,
DsGridDesc_G_M_N
ds_grid_desc_g_m_n
,
EGridDesc_G_M_N
e_grid_desc_g_m_n
)
:
batch_stride_A_
(
batch_stride_A
),
batch_stride_B_
(
batch_stride_B
),
ds_grid_desc_g_m_n_
(
ds_grid_desc_g_m_n
),
e_grid_desc_g_m_n_
(
e_grid_desc_g_m_n
)
{
}
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
static_cast
<
long_index_t
>
(
g_idx
)
*
batch_stride_A_
;
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
static_cast
<
long_index_t
>
(
g_idx
)
*
batch_stride_B_
;
}
__host__
__device__
constexpr
auto
GetDsPtrOffset
(
index_t
g_idx
)
const
{
std
::
array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_offset
[
i
]
=
static_cast
<
long_index_t
>
(
g_idx
)
*
ds_grid_desc_g_m_n_
[
i
].
CalculateOffset
(
make_multi_index
(
1
,
0
,
0
));
});
return
ds_offset
;
}
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
static_cast
<
long_index_t
>
(
g_idx
)
*
e_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
1
,
0
,
0
));
}
private:
index_t
batch_stride_A_
;
index_t
batch_stride_B_
;
DsGridDesc_G_M_N
ds_grid_desc_g_m_n_
;
EGridDesc_G_M_N
e_grid_desc_g_m_n_
;
};
// A desc for source in blockwise copy
template
<
typename
AGridDesc_M_K
>
__host__
__device__
static
constexpr
auto
MakeAGridDescriptor_K0_M_K1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
)
{
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
AK0
=
K
/
K1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
K1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// B desc for source in blockwise copy
template
<
typename
BGridDesc_N_K
>
__host__
__device__
static
constexpr
auto
MakeBGridDescriptor_K0_N_K1
(
const
BGridDesc_N_K
&
b_grid_desc_n_k
)
{
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
BK0
=
K
/
K1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
using
AGridDesc_K0_M_K1
=
decltype
(
DeviceOp
::
MakeAGridDescriptor_K0_M_K1
(
AGridDesc_M_K
{}));
using
BGridDesc_K0_N_K1
=
decltype
(
DeviceOp
::
MakeBGridDescriptor_K0_N_K1
(
BGridDesc_N_K
{}));
// GridwiseOp
using
GridwiseOp
=
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
<
// DataType Family
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
// InMemory Data Descriptor
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
// ElementwiseOp Family
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
// Tiling Family
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerWMMA
,
NPerWMMA
,
K1
,
MRepeat
,
NRepeat
,
// ThreadCluster Family
BlockSize
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
NumPrefetch
,
LoopSched
,
PipelineVer
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
void
*
p_a_grid
,
const
void
*
p_b_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
void
*
p_e_grid
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_strides
,
index_t
M01
,
index_t
N01
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
a_grid_desc_m_k_
{},
b_grid_desc_n_k_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{},
ds_grid_desc_g_m_n_
{
DeviceOp
::
MakeDsGridDescriptor_G_M_N
(
ds_gs_ms_ns_lengths
,
ds_gs_ms_ns_strides
)},
e_grid_desc_g_m_n_
{
DeviceOp
::
MakeEGridDescriptor_G_M_N
(
e_gs_ms_ns_lengths
,
e_gs_ms_ns_strides
)},
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock
{},
e_grid_desc_mblock_mperblock_nblock_nperblock
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
a_mz_stride_
{},
a_kz_stride_
{},
b_nz_stride_
{},
b_kz_stride_
{},
ds_nz_stride_
{},
e_nz_stride_
{},
a_batch_stride_
{
a_gs_ms_ks_strides
[
NumDimG
-
1
]},
b_batch_stride_
{
b_gs_ns_ks_strides
[
NumDimG
-
1
]},
compute_ptr_offset_of_batch_
{
a_batch_stride_
,
b_batch_stride_
,
ds_grid_desc_g_m_n_
,
e_grid_desc_g_m_n_
}
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
// D pointer
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds_grid
[
i
]);
});
a_grid_desc_m_k_
=
DeviceOp
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
b_grid_desc_n_k_
=
DeviceOp
::
MakeBGridDescriptor_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
);
ds_grid_desc_m_n_
=
DeviceOp
::
MakeDsGridDescriptor_M_N
(
ds_gs_ms_ns_lengths
,
ds_gs_ms_ns_strides
);
e_grid_desc_m_n_
=
DeviceOp
::
MakeEGridDescriptor_M_N
(
e_gs_ms_ns_lengths
,
e_gs_ms_ns_strides
);
a_grid_desc_k0_m_k1_
=
DeviceOp
::
MakeAGridDescriptor_K0_M_K1
(
a_grid_desc_m_k_
);
b_grid_desc_k0_n_k1_
=
DeviceOp
::
MakeBGridDescriptor_K0_N_K1
(
b_grid_desc_n_k_
);
block_2_ctile_map_
=
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
e_grid_desc_m_n_
,
M01
,
N01
);
ds_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseOp
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
e_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseOp
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
// for sanity check of vector memory access
a_mz_stride_
=
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
-
1
];
a_kz_stride_
=
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
+
NumDimK
-
1
];
b_nz_stride_
=
b_gs_ns_ks_strides
[
NumDimG
+
NumDimN
-
1
];
b_kz_stride_
=
b_gs_ns_ks_strides
[
NumDimG
+
NumDimN
+
NumDimK
-
1
];
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
{
ds_nz_stride_
[
i
]
=
ds_gs_ms_ns_strides
[
i
][
NumDimG
+
NumDimM
+
NumDimN
-
1
];
}
e_nz_stride_
=
e_gs_ms_ns_strides
[
NumDimG
+
NumDimM
+
NumDimN
-
1
];
}
// Pointers
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
typename
GridwiseOp
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
// Tensor Descriptors
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
DsGridDesc_G_M_N
ds_grid_desc_g_m_n_
;
EGridDesc_G_M_N
e_grid_desc_g_m_n_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
;
// Block to Tile mapping
typename
GridwiseOp
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
// Idle
index_t
M01_
;
index_t
N01_
;
// ElementwiseOp
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// Strides for the last M/N/K dimensions of A/B/Ds/E
// for sanity check of vector load/store
index_t
a_mz_stride_
;
index_t
a_kz_stride_
;
index_t
b_nz_stride_
;
index_t
b_kz_stride_
;
std
::
array
<
index_t
,
NumDTensor
>
ds_nz_stride_
;
index_t
e_mz_stride_
;
index_t
e_nz_stride_
;
index_t
a_batch_stride_
;
index_t
b_batch_stride_
;
// Batch Offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
index_t
G
=
arg
.
e_grid_desc_g_m_n_
.
GetLength
(
I0
);
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
)
*
G
;
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_contraction_multiple_d_wmma_cshuffle
<
GridwiseOp
,
ADataType
,
BDataType
,
typename
GridwiseOp
::
DsGridPointer
,
EDataType
,
DeviceOp
::
AGridDesc_K0_M_K1
,
DeviceOp
::
BGridDesc_K0_N_K1
,
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
ComputePtrOffsetOfStridedBatch
,
typename
GridwiseOp
::
DefaultBlock2CTileMap
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
G
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
block_2_ctile_map_
);
};
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
get_device_name
()
==
"gfx1100"
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
return
false
;
}
}
else
{
return
false
;
}
if
(
!
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
{
return
false
;
}
// check vector access
static_assert
((
ABlockTransferSrcVectorDim
==
1
||
ABlockTransferSrcVectorDim
==
2
)
&&
(
BBlockTransferSrcVectorDim
==
1
||
BBlockTransferSrcVectorDim
==
2
),
"wrong!"
);
// vector memory access of A: could be on M or AK1 dimension
if
constexpr
(
ABlockTransferSrcVectorDim
==
1
)
{
if
(
!
(
arg
.
a_mz_stride_
==
1
&&
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
if
(
!
(
arg
.
a_kz_stride_
==
1
&&
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
// vector memory access of B: could be on N or BK1 dimension
if
constexpr
(
BBlockTransferSrcVectorDim
==
1
)
{
if
(
!
(
arg
.
b_nz_stride_
==
1
&&
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
if
(
!
(
arg
.
b_kz_stride_
==
1
&&
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
// vector memory access of Ds: always on NPerBlock dimension
bool
valid_d_access
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
if
(
!
(
arg
.
ds_nz_stride_
[
i
]
==
1
&&
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
i
].
GetLength
(
I3
)
%
CDEShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
valid_d_access
=
false
;
}
});
if
(
valid_d_access
==
false
)
{
return
false
;
}
// vector memory access of E: always on NPerBlock dimension
if
(
!
((
arg
.
e_nz_stride_
==
1
&&
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I3
)
%
CDEShuffleBlockTransferScalarPerVector_NPerBlock
==
0
)
||
CDEShuffleBlockTransferScalarPerVector_NPerBlock
==
1
))
{
return
false
;
}
return
true
;
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_strides
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_e
,
a_gs_ms_ks_lengths
,
b_gs_ns_ks_lengths
,
ds_gs_ms_ns_lengths
,
e_gs_ms_ns_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_strides
,
ds_gs_ms_ns_strides
,
e_gs_ms_ns_strides
,
1
,
1
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_strides
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
a_gs_ms_ks_lengths
,
b_gs_ns_ks_lengths
,
ds_gs_ms_ns_lengths
,
e_gs_ms_ns_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_strides
,
ds_gs_ms_ns_strides
,
e_gs_ms_ns_strides
,
1
,
1
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
// clang-format off
str
<<
"DeviceBatchedContractionMultipleD_Wmma_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
", "
<<
K1
<<
", "
<<
MPerWMMA
<<
", "
<<
NPerWMMA
<<
", "
<<
MRepeat
<<
", "
<<
NRepeat
<<
">"
<<
" NumPrefetch: "
<<
NumPrefetch
<<
", "
<<
"LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
a72a5762
...
...
@@ -111,6 +111,7 @@ __global__ void
ignore
=
p_b_grid
;
ignore
=
p_b1_grid
;
ignore
=
p_c_grid
;
ignore
=
p_d0s_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c0de_element_op
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
a72a5762
...
...
@@ -586,6 +586,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return
false
;
}
if
(
ck
::
get_device_name
()
!=
"gfx90a"
&&
std
::
is_same
<
ADataType
,
double
>::
value
)
{
return
false
;
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp
View file @
a72a5762
...
...
@@ -533,6 +533,11 @@ struct DeviceElementwiseNormalizationImpl
return
(
false
);
}
if
(
p_arg_
->
x_lds_size_
>=
65536
)
{
return
(
false
);
}
return
true
;
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
a72a5762
...
...
@@ -669,6 +669,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmWelford has invalid setting"
);
}
if
(
arg
.
p_workspace_e_grid_
==
nullptr
||
arg
.
p_workspace_mean_
==
nullptr
||
arg
.
p_workspace_var_
==
nullptr
||
arg
.
p_workspace_count_
==
nullptr
)
throw
std
::
runtime_error
(
"wrong! WorkSpace pointer has not been set"
);
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
gemm_e_grid_desc_m_n_
);
...
...
@@ -939,7 +942,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
}
}
return
true
;
return
GridwiseGemmWelford
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
gemm_e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
// polymorphic
...
...
@@ -1055,7 +1062,12 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
<<
GemmKPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
PostShuffleThreadClusterSize_M_N
::
At
(
I0
)
<<
", "
<<
PostShuffleThreadClusterSize_M_N
::
At
(
I1
)
<<
", "
<<
LayernormThreadClusterSize_M_N
::
At
(
I0
)
<<
", "
<<
LayernormThreadClusterSize_M_N
::
At
(
I1
)
<<
", "
<<
LayernormThreadSliceSize_M
<<
">"
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
0 → 100644
View file @
a72a5762
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
GemmSpecialization
GemmSpec
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerWMMA
,
ck
::
index_t
NPerWMMA
,
ck
::
index_t
MRepeat
,
ck
::
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
ck
::
index_t
NumPrefetch
=
1
,
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
>
struct
DeviceGemmMultipleD_Wmma_CShuffle
:
public
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
using
DeviceOp
=
DeviceGemmMultipleD_Wmma_CShuffle
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
// K1 = Max Vector Access Pixels
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
)
{
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
}
#ifdef ENABLE_COLMAJOR
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
}
#endif
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_right_pad_transform
(
M
,
PadM
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
StrideB
)
{
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
const
auto
b_grid_desc_k_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
StrideB
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
I1
,
StrideB
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_right_pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
template
<
typename
ELayout_
>
static
auto
MakeEGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideE
)
{
const
auto
e_grid_desc_m_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ELayout_
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideE
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ELayout_
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideE
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
return
transform_tensor_descriptor
(
e_grid_desc_m_n
,
make_tuple
(
make_right_pad_transform
(
M
,
PadM
),
make_right_pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
e_grid_desc_m_n
,
make_tuple
(
make_pass_through_transform
(
M
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NumDTensor
>&
Ms
,
const
std
::
array
<
index_t
,
NumDTensor
>&
Ns
,
const
std
::
array
<
index_t
,
NumDTensor
>&
DsStride
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
Ms
[
i
],
Ns
[
i
],
DsStride
[
i
]);
},
Number
<
NumDTensor
>
{});
}
// Gridwise descriptor, mapping to whole given provblem.
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_K0_M_K1
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
// GridwiseOp
using
GridwiseOp
=
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
<
// DataType Family
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
// InMemory Data Descriptor
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
// ElementwiseOp Family
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
// Tiling Family
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerWMMA
,
NPerWMMA
,
K1
,
MRepeat
,
NRepeat
,
// ThreadCluster Family
BlockSize
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
NumPrefetch
,
LoopSched
,
PipelineVer
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
void
*
p_a_grid
,
const
void
*
p_b_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
void
*
p_e_grid
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideE
,
index_t
M01
,
index_t
N01
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock
{},
e_grid_desc_mblock_mperblock_nblock_nperblock
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
{
a_grid_desc_k0_m_k1_
=
DeviceOp
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
=
DeviceOp
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
// D pointer
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds_grid
[
i
]);
// D desc
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
M
,
N
,
StrideDs
[
i
]);
});
e_grid_desc_m_n_
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
M
,
N
,
StrideE
);
block_2_ctile_map_
=
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
e_grid_desc_m_n_
,
M01
,
N01
);
if
(
GridwiseOp
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_ctile_map_
))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseOp
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
e_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseOp
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
}
}
// Pointers
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
typename
GridwiseOp
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
// Tensor Descriptors
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
;
// Block to Tile mapping
typename
GridwiseOp
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
// Idle
index_t
M01_
;
index_t
N01_
;
// ElementwiseOp
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if 0
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
}
#endif
if
(
!
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
{
const
auto
kernel
=
kernel_gemm_mupltipe_d_wmma_cshuffle
<
GridwiseOp
,
ADataType
,
BDataType
,
typename
GridwiseOp
::
DsGridPointer
,
EDataType
,
remove_reference_t
<
typename
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
typename
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
remove_reference_t
<
typename
GridwiseOp
::
DefaultBlock2CTileMap
>
,
true
>
;
// Last Option is W/O
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
block_2_ctile_map_
);
}
else
{
const
auto
kernel
=
kernel_gemm_mupltipe_d_wmma_cshuffle
<
GridwiseOp
,
ADataType
,
BDataType
,
typename
GridwiseOp
::
DsGridPointer
,
EDataType
,
remove_reference_t
<
typename
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
typename
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
remove_reference_t
<
typename
GridwiseOp
::
DefaultBlock2CTileMap
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
block_2_ctile_map_
);
}
return
ave_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
get_device_name
()
==
"gfx1100"
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
return
false
;
}
}
else
{
return
false
;
}
return
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_e
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideDs
,
StrideE
,
1
,
1
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideDs
,
StrideE
,
1
,
1
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
// clang-format off
str
<<
"DeviceGemmMultipleD_Wmma_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
", "
<<
K1
<<
", "
<<
MPerWMMA
<<
", "
<<
NPerWMMA
<<
", "
<<
MRepeat
<<
", "
<<
NRepeat
<<
">"
<<
" NumPrefetch: "
<<
NumPrefetch
<<
", "
<<
"LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
0 → 100644
View file @
a72a5762
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
{
struct
ComputePtrOffsetOfStridedBatch
{
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideC_
;
};
}
// namespace
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_B_K0_M0_M1_K1
,
typename
BGridDesc_B_K0_N0_N1_K1
,
typename
CGridDesc_M0_M10_M11_N0_N10_N11
,
typename
Block2CTileMap
,
typename
ComputePtrOffsetOfBatch
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_batched_gemm_dlops_bwd_weight
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
index_t
batch_count
,
const
AGridDesc_B_K0_M0_M1_K1
a_grid_desc_kbatch_k0_m0_m1_k1
,
const
BGridDesc_B_K0_N0_N1_K1
b_grid_desc_kbatch_k0_n0_n1_k1
,
const
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11
,
const
Block2CTileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
{
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetCPtrOffset
(
g_idx
)));
__shared__
FloatAB
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
)];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
HasDoubleTailKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_shared
,
a_grid_desc_kbatch_k0_m0_m1_k1
,
b_grid_desc_kbatch_k0_n0_n1_k1
,
c_grid_desc_m0_m10_m11_n0_n10_n11
,
block_2_ctile_map
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionBackwardWeightSpecialization
ConvBackwardWeightSpecialization
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
index_t
M1PerThread
,
index_t
N1PerThread
,
index_t
KPerThread
,
typename
M1N1ThreadClusterM1Xs
,
typename
M1N1ThreadClusterN1Xs
,
typename
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
typename
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
,
typename
ABlockTransferSrcVectorTensorContiguousDimOrder
,
typename
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
,
typename
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
,
typename
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
typename
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
,
typename
BBlockTransferSrcVectorTensorContiguousDimOrder
,
typename
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
>
struct
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
:
public
DeviceGroupedConvBwdWeight
<
NDimSpatial
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GNWC
,
ck
::
tensor_layout
::
convolution
::
GNHWC
,
ck
::
tensor_layout
::
convolution
::
GNDHWC
>>
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GKXC
,
ck
::
tensor_layout
::
convolution
::
GKYXC
,
ck
::
tensor_layout
::
convolution
::
GKZYXC
>>
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GNWK
,
ck
::
tensor_layout
::
convolution
::
GNHWK
,
ck
::
tensor_layout
::
convolution
::
GNDHWK
>>
,
InDataType
,
WeiDataType
,
OutDataType
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
using
DeviceOp
=
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
;
using
ADataType
=
OutDataType
;
using
BDataType
=
InDataType
;
using
CDataType
=
WeiDataType
;
using
AElementwiseOperation
=
OutElementwiseOperation
;
using
BElementwiseOperation
=
InElementwiseOperation
;
using
CElementwiseOperation
=
WeiElementwiseOperation
;
// TODO make A/B datatype different
using
ABDataType
=
InDataType
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
GemmK1Number
=
K1Number
;
// Bytes per 32 lds bank: 32 * 4 bytes
static
constexpr
auto
BankLength
=
128
;
static
constexpr
auto
ElePerBank
=
BankLength
/
sizeof
(
ADataType
);
// M1 & M0
static
constexpr
auto
ABlockLdsM1PerBlock
=
ElePerBank
/
K1
;
static
constexpr
auto
ABlockLdsM0PerBlock
=
MPerBlock
/
ABlockLdsM1PerBlock
;
static
constexpr
auto
ABlockLdsM1Padding
=
4
;
// N1 & N0
static
constexpr
auto
BBlockLdsN1PerBlock
=
ElePerBank
/
K1
;
static
constexpr
auto
BBlockLdsN0PerBlock
=
NPerBlock
/
BBlockLdsN1PerBlock
;
static
constexpr
auto
BBlockLdsN1Padding
=
4
;
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
{
using
namespace
ck
;
const
index_t
Wi
=
input_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
0
];
const
index_t
X
=
filter_spatial_lengths
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
0
];
const
index_t
GemmKTotal
=
N
*
Wo
;
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
C
*
X
;
const
index_t
GemmKBatch
=
batch_k
;
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
*
GemmKBatch
)
*
K0PerBlock
;
const
index_t
GemmKPad
=
GemmKBatch
*
GemmK0
*
GemmK1Number
;
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
// A: output tensor
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Wo
,
K
));
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// B: input tensor
const
auto
in_gemmktotal_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Wi
,
C
));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// C: weights tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
X
*
C
));
return
make_tuple
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
}
else
{
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Wo
,
K
));
const
auto
in_n_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
// A: output tensor
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// B: input tensor
const
auto
in_n_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_n_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
X
*
C
));
return
make_tuple
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
}
}
// function end
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
{
using
namespace
ck
;
const
index_t
Hi
=
input_spatial_lengths
[
0
];
const
index_t
Wi
=
input_spatial_lengths
[
1
];
const
index_t
Ho
=
output_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
1
];
const
index_t
Y
=
filter_spatial_lengths
[
0
];
const
index_t
X
=
filter_spatial_lengths
[
1
];
const
index_t
InLeftPadH
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
1
];
const
index_t
InRightPadH
=
input_right_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
1
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
const
index_t
GemmKTotal
=
N
*
Ho
*
Wo
;
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
C
*
X
*
Y
;
const
index_t
GemmKBatch
=
batch_k
;
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
*
GemmKBatch
)
*
K0PerBlock
;
const
index_t
GemmKPad
=
GemmKBatch
*
GemmK0
*
GemmK1Number
;
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
// A: output tensor
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
));
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// B: input tensor
const
auto
in_gemmktotal_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Hi
*
Wi
,
C
));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
));
return
make_tuple
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
}
else
{
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
));
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
// A: output tensor
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// B: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
));
return
make_tuple
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
}
}
// function end
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
{
using
namespace
ck
;
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
const
index_t
Do
=
output_spatial_lengths
[
0
];
const
index_t
Ho
=
output_spatial_lengths
[
1
];
const
index_t
Wo
=
output_spatial_lengths
[
2
];
const
index_t
Z
=
filter_spatial_lengths
[
0
];
const
index_t
Y
=
filter_spatial_lengths
[
1
];
const
index_t
X
=
filter_spatial_lengths
[
2
];
const
index_t
InLeftPadD
=
input_left_pads
[
0
];
const
index_t
InLeftPadH
=
input_left_pads
[
1
];
const
index_t
InLeftPadW
=
input_left_pads
[
2
];
const
index_t
InRightPadD
=
input_right_pads
[
0
];
const
index_t
InRightPadH
=
input_right_pads
[
1
];
const
index_t
InRightPadW
=
input_right_pads
[
2
];
const
index_t
ConvStrideD
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
1
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
2
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
1
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
2
];
const
index_t
GemmKTotal
=
N
*
Do
*
Ho
*
Wo
;
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
C
*
Z
*
X
*
Y
;
const
index_t
GemmKBatch
=
batch_k
;
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
*
GemmKBatch
)
*
K0PerBlock
;
const
index_t
GemmKPad
=
GemmKBatch
*
GemmK0
*
GemmK1Number
;
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
// A: output tensor
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
,
K
));
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// B: input tensor
const
auto
in_gemmktotal_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Di
*
Hi
*
Wi
,
C
));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Z
*
Y
*
X
*
C
));
return
make_tuple
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
}
else
{
const
auto
out_gemmktotal_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
,
K
));
const
auto
in_n_di_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
// A: output tensor
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// B: input tensor
const
auto
in_n_dip_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_di_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Di
,
InLeftPadD
,
InRightPadD
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_n_z_do_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_dip_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Z
,
Do
),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
const
auto
in_gemmktotal_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_z_do_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Z
,
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
,
7
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Z
*
Y
*
X
*
C
));
return
make_tuple
(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
}
}
// function end
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
GetABCGridDesc
()
{
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
1
>
(
1
,
1
,
1
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
1
);
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
GetABCGridDesc
()
{
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
2
>
(
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
1
);
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
GetABCGridDesc
()
{
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
3
>
(
1
,
1
,
1
,
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
1
);
}
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NDimSpatial
>
());
using
AGridDesc_B_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
BGridDesc_B_K0_N_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
using
GridwiseGemm
=
GridwiseGemmDl_bkm_bkn_mn_v1r3
<
BlockSize
,
ADataType
,
AccDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_B_K0_M_K1
,
BGridDesc_B_K0_N_K1
,
CGridDesc_M_N
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
K1
,
M1PerThread
,
N1PerThread
,
KPerThread
,
M1N1ThreadClusterM1Xs
,
M1N1ThreadClusterN1Xs
,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
,
ABlockTransferSrcVectorTensorContiguousDimOrder
,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
,
BBlockTransferSrcVectorTensorContiguousDimOrder
,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
>
;
// Argument
using
AGridDesc_B_K0_M0_M1_K1
=
decltype
(
GridwiseGemm
::
MakeAGridDescriptor_B_K0_M0_M1_K1
(
AGridDesc_B_K0_M_K1
{}));
using
BGridDesc_B_K0_N0_N1_K1
=
decltype
(
GridwiseGemm
::
MakeBGridDescriptor_B_K0_N0_N1_K1
(
BGridDesc_B_K0_N_K1
{}));
using
CGridDesc_M0_M10_M11_N0_N10_N11
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
CGridDesc_M_N
{},
1
,
1
,
1
));
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
ck
::
index_t
split_k
)
:
p_a_grid_
{
p_out_grid
},
p_b_grid_
{
p_in_grid
},
p_c_grid_
{
p_wei_grid
},
a_grid_desc_kbatch_k0_m_k1_
{},
b_grid_desc_kbatch_k0_n_k1_
{},
c_grid_desc_m_n_
{},
block_2_ctile_map_
{},
compute_ptr_offset_of_batch_
{},
a_element_op_
{
out_element_op
},
b_element_op_
{
wei_element_op
},
c_element_op_
{
in_element_op
},
Conv_G_
{
G
},
Conv_N_
{
N
},
Conv_K_
{
K
},
Conv_C_
{
C
},
input_spatial_lengths_
{
input_spatial_lengths
},
filter_spatial_lengths_
{
filter_spatial_lengths
},
output_spatial_lengths_
{
output_spatial_lengths
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
},
k_batch_
{
split_k
}
{
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
k_batch_
);
a_grid_desc_kbatch_k0_m_k1_
=
descs
[
I0
];
b_grid_desc_kbatch_k0_n_k1_
=
descs
[
I1
];
c_grid_desc_m_n_
=
descs
[
I2
];
a_grid_desc_kbatch_k0_m0_m1_k1_
=
GridwiseGemm
::
MakeAGridDescriptor_B_K0_M0_M1_K1
(
a_grid_desc_kbatch_k0_m_k1_
);
b_grid_desc_kbatch_k0_n0_n1_k1_
=
GridwiseGemm
::
MakeBGridDescriptor_B_K0_N0_N1_K1
(
b_grid_desc_kbatch_k0_n_k1_
);
c_grid_desc_m0_m10_m11_n0_n10_n11_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
c_grid_desc_m_n_
);
ck
::
index_t
M01
=
1
;
ck
::
index_t
N01
=
1
;
block_2_ctile_map_
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_grid_desc_m_n_
,
M01
,
N01
,
k_batch_
);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
N
*
K
*
std
::
accumulate
(
begin
(
output_spatial_lengths
),
end
(
output_spatial_lengths
),
index_t
{
1
},
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
N
*
C
*
std
::
accumulate
(
begin
(
input_spatial_lengths
),
end
(
input_spatial_lengths
),
index_t
{
1
},
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
K
*
C
*
std
::
accumulate
(
begin
(
filter_spatial_lengths
),
end
(
filter_spatial_lengths
),
index_t
{
1
},
std
::
multiplies
<>
{});
}
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
AGridDesc_B_K0_M_K1
a_grid_desc_kbatch_k0_m_k1_
;
BGridDesc_B_K0_N_K1
b_grid_desc_kbatch_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
AGridDesc_B_K0_M0_M1_K1
a_grid_desc_kbatch_k0_m0_m1_k1_
;
BGridDesc_B_K0_N0_N1_K1
b_grid_desc_kbatch_k0_n0_n1_k1_
;
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11_
;
// DefaultBlock2CTileMap block_2_ctile_map_;
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
// element-wise op
OutElementwiseOperation
a_element_op_
;
WeiElementwiseOperation
b_element_op_
;
InElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
index_t
Conv_G_
;
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads_
;
index_t
k_batch_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
void
ShowInfo
(
const
Argument
&
arg
)
{
std
::
cout
<<
"arg.a_grid_desc_kbatch_k0_m_k1_{"
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_kbatch_k0_n_k1_{"
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
ShowInfo
(
arg
);
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm GridwiseGemmDl_bkm_bkn_mn_v1r3 has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
Conv_G_
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
,
auto
has_double_tail_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
constexpr
bool
has_double_loop
=
has_double_tail_k_block_loop
.
value
;
const
auto
kernel
=
kernel_batched_gemm_dlops_bwd_weight
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_B_K0_M0_M1_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_B_K0_N0_N1_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
,
has_main_loop
,
has_double_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
Conv_G_
,
arg
.
a_grid_desc_kbatch_k0_m0_m1_k1_
,
arg
.
b_grid_desc_kbatch_k0_n0_n1_k1_
,
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_
,
arg
.
block_2_ctile_map_
,
arg
.
compute_ptr_offset_of_batch_
);
};
const
auto
K0
=
arg
.
a_grid_desc_kbatch_k0_m0_m1_k1_
.
GetLength
(
I1
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K0
);
const
bool
has_double_tail_k_block_loop
=
GridwiseGemm
::
CalculateHasDoubleTailKBlockLoop
(
K0
);
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
// check device
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx1030"
))
{
return
false
;
}
if
constexpr
(
ConvBackwardWeightSpecialization
==
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
// check if it's 1x1, stride=1 pad = 0 conv
for
(
int
i
=
0
;
i
<
NDimSpatial
;
i
++
)
{
if
(
!
(
arg
.
filter_spatial_lengths_
[
i
]
==
1
&&
arg
.
conv_filter_strides_
[
i
]
==
1
&&
arg
.
input_left_pads_
[
i
]
==
0
&&
arg
.
input_right_pads_
[
i
]
==
0
))
{
return
false
;
}
}
}
// matrix A
{
auto
srcVectorLengths
=
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
{};
if
(
srcVectorLengths
[
I2
]
!=
1
||
srcVectorLengths
[
I3
]
!=
1
)
{
return
false
;
}
if
(
K1
%
srcVectorLengths
[
I4
]
!=
0
||
K0PerBlock
%
srcVectorLengths
[
I1
]
!=
0
)
{
return
false
;
}
const
index_t
K
=
arg
.
Conv_K_
;
if
(
K
%
(
srcVectorLengths
[
I1
]
*
srcVectorLengths
[
I4
])
!=
0
)
{
return
false
;
}
}
// matrix B
{
auto
srcLoadLenghts
=
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
{};
auto
srcVectorLengths
=
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
{};
if
(
srcVectorLengths
[
I1
]
!=
1
||
srcVectorLengths
[
I4
]
!=
1
)
{
return
false
;
}
if
(
srcLoadLenghts
[
I2
]
%
srcVectorLengths
[
I2
]
!=
0
||
srcLoadLenghts
[
I3
]
%
srcVectorLengths
[
I3
]
!=
0
)
{
return
false
;
}
const
index_t
C
=
arg
.
Conv_K_
;
if
(
C
%
(
srcVectorLengths
[
I2
]
*
srcVectorLengths
[
I3
])
!=
0
)
{
return
false
;
}
}
// vector store C matrix into global memory
if
(
!
(
arg
.
Conv_C_
%
CThreadTransferDstScalarPerVector
==
0
))
{
std
::
cout
<<
"Not surpport,because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = "
<<
arg
.
Conv_C_
%
CThreadTransferDstScalarPerVector
<<
std
::
endl
;
return
false
;
}
// Gridwise GEMM size
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
ck
::
index_t
split_k
)
{
return
Argument
{
p_in_grid
,
p_wei_grid
,
p_out_grid
,
G
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
in_element_op
,
wei_element_op
,
out_element_op
,
split_k
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
ck
::
index_t
split_k
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
G
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
in_element_op
,
wei_element_op
,
out_element_op
,
split_k
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
", "
<<
getConvBackwardWeightSpecializationString
(
ConvBackwardWeightSpecialization
)
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
0 → 100644
View file @
a72a5762
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
{
template
<
index_t
NumDTensor
>
struct
ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch
()
=
default
;
ComputePtrOffsetOfStridedBatch
(
index_t
BatchStrideA
,
index_t
BatchStrideB
,
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs
,
index_t
BatchStrideE
)
:
BatchStrideA_
(
BatchStrideA
),
BatchStrideB_
(
BatchStrideB
),
BatchStrideDs_
(
BatchStrideDs
),
BatchStrideE_
(
BatchStrideE
)
{
}
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
auto
GetDsPtrOffset
(
index_t
g_idx
)
const
{
Array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
ds_offset
(
i
)
=
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideDs_
[
i
]);
});
return
ds_offset
;
}
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
};
}
// namespace
//
// @brief Device Convolution operation.
//
// Supports:
// @li Forward convolution with up to 3 spatial dimentions
// @li Input tensor in GNWC data format
// @li Weight tensor in GKXC data format
// @li Output tensor in GNWK data format
//
// 1D:
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
// 2D:
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
// Assume:
// AK1 == BK1
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
GemmSpecialization
GemmSpec
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerWMMA
,
ck
::
index_t
NPerWMMA
,
ck
::
index_t
MRepeat
,
ck
::
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
NumGemmKPrefetchStage
=
1
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
>
struct
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
:
public
DeviceGroupedConvFwdMultipleD
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
using
DeviceOp
=
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
index_t
KPerBlock
=
K0PerBlock
*
K1
;
static
constexpr
auto
conv_to_gemm_transformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
{};
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
return
in_gemmm_gemmk_desc
;
}
template
<
typename
BLay
>
static
auto
MakeBGridDescriptor_N_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
)
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
);
const
auto
wei_gemmn_gemmk_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_gemmnraw_gemmkraw_desc
);
return
wei_gemmn_gemmk_desc
;
}
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
)
{
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
return
out_gemmm_gemmn_desc
;
}
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
ds_g_n_k_wos_lengths
[
i
],
ds_g_n_k_wos_strides
[
i
]);
},
Number
<
NumDTensor
>
{});
}
// desc for problem definition
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
({},
{},
{},
{},
{},
{},
{},
{},
{},
{}))
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_N_K
<
BLayout
>
({},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{}))
>
;
// A desc for source in blockwise copy
template
<
typename
AGridDesc_M_K
>
__host__
__device__
static
constexpr
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
)
{
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
AK1
=
K1
;
const
auto
AK0
=
K
/
AK1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// B desc for source in blockwise copy
template
<
typename
BGridDesc_N_K
>
__host__
__device__
static
constexpr
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
BGridDesc_N_K
&
b_grid_desc_n_k
)
{
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
BK1
=
K1
;
const
auto
BK0
=
K
/
BK1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}));
// GridwiseOp
using
GridwiseOp
=
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
<
// DataType Family
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
// InMemory Data Descriptor
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
// ElementwiseOp Family
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
// Tiling Family
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerWMMA
,
NPerWMMA
,
K1
,
MRepeat
,
NRepeat
,
// ThreadCluster Family
BlockSize
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
BBlockLdsExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
NumGemmKPrefetchStage
,
LoopSched
,
PipelineVer
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
index_t
M01
,
index_t
N01
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b
)},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
<
BLayout
>
(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
)},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k_
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
e_grid_desc_m_n_
,
M01
,
N01
)},
compute_ptr_offset_of_batch_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
a_g_n_c_wis_strides
},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
ds_g_n_k_wos_lengths_
{
ds_g_n_k_wos_lengths
},
ds_g_n_k_wos_strides_
{
ds_g_n_k_wos_strides
},
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
e_g_n_k_wos_strides_
{
e_g_n_k_wos_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
// A/B/E Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
0
];
// populate pointer, batch stride, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
// using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
// D pointer
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds
[
i
]);
// D batch stride
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
0
];
});
// D desc
ds_grid_desc_m_n_
=
DeviceOp
::
MakeDsGridDescriptor_M_N
(
ds_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
);
// populate desc for Ds/E
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseOp
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseOp
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
}
void
Print
()
const
{
std
::
cout
<<
"A[M, K]: "
<<
a_grid_desc_m_k_
<<
std
::
endl
;
std
::
cout
<<
"B[N, K]: "
<<
b_grid_desc_n_k_
<<
std
::
endl
;
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
std
::
cout
<<
"Ds[M, N]: "
<<
ds_grid_desc_m_n_
[
i
]
<<
std
::
endl
;
});
std
::
cout
<<
"E[M, N]: "
<<
e_grid_desc_m_n_
<<
std
::
endl
;
}
// private:
// pointers
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
typename
GridwiseOp
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
// tensor descriptors for problem definiton
index_t
num_group_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
// block-to-e-tile map
typename
GridwiseOp
::
DefaultBlock2CTileMap
block_2_etile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
arg
.
Print
();
}
const
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
)
*
arg
.
num_group_
;
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle
<
GridwiseOp
,
ADataType
,
BDataType
,
typename
GridwiseOp
::
DsGridPointer
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
remove_reference_t
<
typename
GridwiseOp
::
DefaultBlock2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_g_n_c_wis_lengths_
[
0
],
// Group count
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_etile_map_
,
arg
.
compute_ptr_offset_of_batch_
);
};
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
namespace
ctc
=
tensor_layout
::
convolution
;
// check device
if
(
get_device_name
()
==
"gfx1100"
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
return
false
;
}
}
else
{
return
false
;
}
// check ConvolutionForwardSpecialization
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
// check if it's 1x1, stride=1 conv
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
const
index_t
X
=
arg
.
b_g_k_c_xs_lengths_
[
i
+
2
];
const
index_t
ConvStride
=
arg
.
conv_filter_strides_
[
i
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
if
(
!
(
X
==
1
&&
ConvStride
==
1
&&
LeftPad
==
0
&&
RightPad
==
0
))
{
return
false
;
}
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
// check if it's 1x1 conv
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
const
index_t
X
=
arg
.
b_g_k_c_xs_lengths_
[
i
+
2
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
if
(
!
(
X
==
1
&&
LeftPad
==
0
&&
RightPad
==
0
))
{
return
false
;
}
}
}
// check vector access of A
// FIXME: layout
if
constexpr
(
is_same_v
<
ALayout
,
ctc
::
G_NW_C
>
||
is_same_v
<
ALayout
,
ctc
::
G_NHW_C
>
||
is_same_v
<
ALayout
,
ctc
::
G_NDHW_C
>
||
is_same_v
<
ALayout
,
ctc
::
GNWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNHWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNDHWC
>
||
is_same_v
<
ALayout
,
ctc
::
NWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NHWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NDHWGC
>
)
{
const
index_t
C
=
arg
.
a_g_n_c_wis_lengths_
[
2
];
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
C
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// check vector access of B
// FIXME: layout
if
constexpr
(
is_same_v
<
BLayout
,
ctc
::
G_K_X_C
>
||
is_same_v
<
BLayout
,
ctc
::
G_K_YX_C
>
||
is_same_v
<
BLayout
,
ctc
::
G_K_ZYX_C
>
||
is_same_v
<
BLayout
,
ctc
::
GKXC
>
||
is_same_v
<
BLayout
,
ctc
::
GKYXC
>
||
is_same_v
<
BLayout
,
ctc
::
GKZYXC
>
||
is_same_v
<
BLayout
,
ctc
::
KXGC
>
||
is_same_v
<
BLayout
,
ctc
::
KYXGC
>
||
is_same_v
<
BLayout
,
ctc
::
KZYXGC
>
)
{
const
index_t
C
=
arg
.
b_g_k_c_xs_lengths_
[
2
];
if
(
!
(
BBlockTransferSrcVectorDim
==
2
&&
C
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// check vector access of Ds
bool
valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
// FIXME: layout
if
constexpr
(
is_same_v
<
DLayout
,
ctc
::
G_NW_K
>
||
is_same_v
<
DLayout
,
ctc
::
G_NHW_K
>
||
is_same_v
<
DLayout
,
ctc
::
G_NDHW_K
>
||
is_same_v
<
DLayout
,
ctc
::
GNWK
>
||
is_same_v
<
DLayout
,
ctc
::
GNHWK
>
||
is_same_v
<
DLayout
,
ctc
::
GNDHWK
>
||
is_same_v
<
DLayout
,
ctc
::
NWGK
>
||
is_same_v
<
DLayout
,
ctc
::
NHWGK
>
||
is_same_v
<
DLayout
,
ctc
::
NDHWGK
>
||
is_same_v
<
DLayout
,
ctc
::
GK
>
||
is_same_v
<
DLayout
,
ctc
::
G_K
>
)
{
const
index_t
K
=
arg
.
ds_g_n_k_wos_lengths_
[
i
][
2
];
if
(
!
(
K
%
CDEShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
valid
=
false
;
}
}
else
{
valid
=
false
;
}
});
if
(
!
valid
)
{
return
false
;
}
// check vector access of E
if
constexpr
(
is_same_v
<
ELayout
,
ctc
::
G_NW_K
>
||
is_same_v
<
ELayout
,
ctc
::
G_NHW_K
>
||
is_same_v
<
ELayout
,
ctc
::
G_NDHW_K
>
||
is_same_v
<
ELayout
,
ctc
::
GNWK
>
||
is_same_v
<
ELayout
,
ctc
::
GNHWK
>
||
is_same_v
<
ELayout
,
ctc
::
GNDHWK
>
||
is_same_v
<
ELayout
,
ctc
::
NWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NHWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NDHWGK
>
)
{
const
index_t
K
=
arg
.
e_g_n_k_wos_lengths_
[
2
];
if
(
!
(
K
%
CDEShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// check Gridwise GEMM
return
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
1
,
1
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
1
,
1
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
getConvForwardSpecializationString
(
ConvForwardSpecialization
)
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment