Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e0041ad8
Commit
e0041ad8
authored
May 29, 2023
by
Adam Osewski
Browse files
Merge remote-tracking branch 'origin/develop' into aosewski/drop_cshuffle
parents
3239201e
ac9e01e2
Changes
361
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1380 additions
and
159 deletions
+1380
-159
include/ck/tensor_operation/gpu/device/device_normalization.hpp
...e/ck/tensor_operation/gpu/device/device_normalization.hpp
+8
-8
include/ck/tensor_operation/gpu/device/device_permute.hpp
include/ck/tensor_operation/gpu/device/device_permute.hpp
+0
-1
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
+44
-0
include/ck/tensor_operation/gpu/device/device_reduce.hpp
include/ck/tensor_operation/gpu/device/device_reduce.hpp
+28
-8
include/ck/tensor_operation/gpu/device/device_softmax.hpp
include/ck/tensor_operation/gpu/device/device_softmax.hpp
+4
-6
include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
...ice/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
+4
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
...l/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
+992
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
...pl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
+4
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
...ion/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
+2
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
...gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
+4
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
...ation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
+4
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
..._batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
+40
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
...u/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
+5
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+184
-108
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
...ce/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
+4
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
...sor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
+3
-2
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
...evice/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+9
-2
include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+18
-3
include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
...device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
+9
-2
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
..._fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
+14
-2
No files found.
Too many changes to show.
To preserve performance only
361 of 361+
files are displayed.
Plain diff
Email patch
include/ck/tensor_operation/gpu/device/device_normalization.hpp
View file @
e0041ad8
...
...
@@ -14,9 +14,9 @@ namespace device {
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
Acc
DataType
,
typename
Compute
DataType
,
typename
YDataType
,
typename
Acc
ElementwiseOperation
,
typename
Y
ElementwiseOperation
,
index_t
Rank
,
index_t
NumReduceDim
>
struct
DeviceNormalization
:
public
BaseOperator
...
...
@@ -28,14 +28,14 @@ struct DeviceNormalization : public BaseOperator
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
AccDataTyp
e
epsilon
,
doubl
e
epsilon
,
const
void
*
p_x
,
const
void
*
p_gamma
,
const
void
*
p_beta
,
void
*
p_y
,
void
*
p_savedMean
,
void
*
p_savedInvVar
,
Acc
ElementwiseOperation
acc
_elementwise_op
)
=
0
;
Y
ElementwiseOperation
y
_elementwise_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
...
...
@@ -43,17 +43,17 @@ struct DeviceNormalization : public BaseOperator
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
Acc
DataType
,
typename
Compute
DataType
,
typename
YDataType
,
typename
Acc
ElementwiseOperation
,
typename
Y
ElementwiseOperation
,
index_t
Rank
,
index_t
NumReduceDim
>
using
DeviceNormalizationPtr
=
std
::
unique_ptr
<
DeviceNormalization
<
XDataType
,
GammaDataType
,
BetaDataType
,
Acc
DataType
,
Compute
DataType
,
YDataType
,
Acc
ElementwiseOperation
,
Y
ElementwiseOperation
,
Rank
,
NumReduceDim
>>
;
...
...
include/ck/tensor_operation/gpu/device/device_permute.hpp
View file @
e0041ad8
...
...
@@ -4,7 +4,6 @@
#pragma once
#include <array>
#include <cmath>
#include <memory>
#include <type_traits>
...
...
include/ck/tensor_operation/gpu/device/device_pool
2d
_fwd.hpp
→
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
View file @
e0041ad8
...
...
@@ -3,8 +3,7 @@
#pragma once
#include <iostream>
#include <array>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/utility/reduction_enums.hpp"
...
...
@@ -13,28 +12,33 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
ck
::
ReduceTensorOp
ReduceOpId
>
struct
DevicePool2dFwd
:
public
BaseOperator
template
<
index_t
InOutRank
,
index_t
WindowRank
,
typename
InDataType
,
typename
OutDataType
,
typename
IndexDataType
,
ReduceTensorOp
ReduceOpId
,
bool
OutputIndex
>
struct
DevicePoolFwd
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
in_dev
,
void
*
out_dev
,
void
*
out_indices_dev
,
ck
::
index_t
N
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
2
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>
window_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>
window_strides
,
std
::
array
<
ck
::
index_t
,
2
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
2
>
input_right_pads
)
=
0
;
MakeArgumentPointer
(
const
void
*
p_in_dev
,
void
*
p_out_dev
,
void
*
p_out_indices_dev
,
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
input_stride
,
std
::
vector
<
ck
::
index_t
>
output_stride
,
std
::
vector
<
ck
::
index_t
>
indices_stride
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
pooling_dims
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
ck
::
ReduceTensorOp
ReduceOpId
>
using
DevicePool2dFwdPtr
=
std
::
unique_ptr
<
DevicePool2dFwd
<
ReduceOpId
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_reduce.hpp
View file @
e0041ad8
...
...
@@ -13,10 +13,16 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
Rank
,
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
,
index_t
Rank
,
index_t
NumReduceDim
,
typename
ReduceOperation
,
typename
InElementwiseOperation
,
typename
AccElementwiseOperation
>
typename
AccElementwiseOperation
,
bool
PropagateNan
,
bool
OutputIndex
>
struct
DeviceReduce
:
public
BaseOperator
{
static
constexpr
index_t
NumOutDim
=
(
Rank
-
NumReduceDim
==
0
)
?
1
:
Rank
-
NumReduceDim
;
...
...
@@ -27,8 +33,8 @@ struct DeviceReduce : public BaseOperator
const
std
::
array
<
index_t
,
NumOutDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumOutDim
>
outStrides
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
float
alpha
,
float
beta
,
double
alpha
,
double
beta
,
const
void
*
in_dev
,
const
void
*
in_index_dev
,
void
*
out_dev
,
...
...
@@ -39,12 +45,26 @@ struct DeviceReduce : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
index_t
Rank
,
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
,
index_t
Rank
,
index_t
NumReduceDim
,
typename
ReduceOperation
,
typename
InElementwiseOperation
,
typename
AccElementwiseOperation
>
using
DeviceReducePtr
=
std
::
unique_ptr
<
DeviceReduce
<
Rank
,
NumReduceDim
,
InElementwiseOperation
,
AccElementwiseOperation
>>
;
typename
AccElementwiseOperation
,
bool
PropagateNan
,
bool
OutputIndex
>
using
DeviceReducePtr
=
std
::
unique_ptr
<
DeviceReduce
<
InDataType
,
AccDataType
,
OutDataType
,
Rank
,
NumReduceDim
,
ReduceOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
PropagateNan
,
OutputIndex
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_softmax.hpp
View file @
e0041ad8
...
...
@@ -27,10 +27,8 @@ struct DeviceSoftmax : public BaseOperator
// @param[in] inLengths Input tensor extent(s) from high to low dimension
// @param[in] inStrides Input tensor stride(s) from high to low dimension
// @param[in] reduceDims The dimension(s) the normalization operation is applied
// @param[in] alpha Typeless pointer in host memory storing the alpha scaling
// value as type AccDataType
// @param[in] beta Typeless pointer in host memory storing the beta scaling
// value as type AccDataType
// @param[in] alpha double type value
// @param[in] beta double type value
// @param[in] in_dev Typeless const pointer in device memory storing the input
// tensor
// @param out_dev Typeless pointer in device memory storing the output tensor
...
...
@@ -43,8 +41,8 @@ struct DeviceSoftmax : public BaseOperator
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
int
>
reduceDims
,
const
void
*
alpha
,
const
void
*
beta
,
double
alpha
,
double
beta
,
const
void
*
in_dev
,
void
*
out_dev
,
InElementwiseOp
in_elementwise_op
,
...
...
include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
View file @
e0041ad8
...
...
@@ -56,7 +56,8 @@ __global__ void
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
...
...
@@ -938,7 +939,8 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
0 → 100644
View file @
e0041ad8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner
// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and
// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted
// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into
// TensorSpecialization::Default with NumDimG/M/N/K = 1
//
// Detail- Packed tensor satisfies
// stride_0 = 1
// stride_i = stride_{i - 1} * extent_{i - 1}
// So tensor
// [G0, G1, G2, M, N]
// transposed into tensor
// [G0, G2, G1, M, N]
// with strides
// [G2 * G1 * M * N, G1 * M * N, M * N, N, 1]
// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores some
// strides from input tensor extents so finer dimension information is lost. Merging dimensions is
// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
//
// Might need to expose dimension order to the interface to fully support
// TensorSpecialization::Packed in a traditional sense of "packed" tensor
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
GemmSpecialization
GemmSpec
,
TensorSpecialization
ASpec
,
TensorSpecialization
BSpec
,
TensorSpecialization
DESpec
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerWMMA
,
ck
::
index_t
NPerWMMA
,
ck
::
index_t
MRepeat
,
ck
::
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
ck
::
index_t
NumPrefetch
=
1
,
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
>
struct
DeviceBatchedContractionMultipleD_Wmma_CShuffle
:
public
DeviceBatchedContractionMultipleD
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
using
DeviceOp
=
DeviceBatchedContractionMultipleD_Wmma_CShuffle
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
// K1 = Max Vector Access Pixels
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K0PerBlock
*
K1
};
// Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
{
assert
(
a_gs_ms_ks_lengths_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimK
&&
a_gs_ms_ks_strides_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimK
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
start
,
auto
end
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
start
+
i
];
},
Number
<
end
-
start
>
{});
};
const
auto
a_ms_ks_lengths
=
to_tuple
(
a_gs_ms_ks_lengths_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimK
>
{});
const
auto
a_ms_ks_strides
=
to_tuple
(
a_gs_ms_ks_strides_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimK
>
{});
// dimension Ids for M0, M1, ...
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimM
,
1
>::
type
{};
// dimension Ids for K0, K1, ...
constexpr
auto
kDimIds
=
typename
arithmetic_sequence_gen
<
NumDimM
,
NumDimM
+
NumDimK
,
1
>::
type
{};
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
a_ms_ks_lengths
,
mDimIds
);
// lengths for K0, K1, ...
const
auto
kLengths
=
get_container_subset
(
a_ms_ks_lengths
,
kDimIds
);
if
constexpr
(
ASpec
==
TensorSpecialization
::
Packed
)
{
auto
M
=
container_reduce
(
mLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
K
=
container_reduce
(
kLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
const
auto
a_grid_desc_mraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
a_ms_ks_strides
[
Number
<
NumDimM
-
1
>
{}],
a_ms_ks_strides
[
Number
<
NumDimM
+
NumDimK
-
1
>
{}]));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
else
{
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const
auto
a_grid_desc_ms_ks
=
make_naive_tensor_descriptor
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const
auto
a_grid_desc_mraw_kraw
=
transform_tensor_descriptor
(
a_grid_desc_ms_ks
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
mDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
}
// Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
static
auto
MakeBGridDescriptor_N_K
(
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides_vec
)
{
assert
(
b_gs_ns_ks_lengths_vec
.
size
()
==
NumDimG
+
NumDimN
+
NumDimK
&&
b_gs_ns_ks_strides_vec
.
size
()
==
NumDimG
+
NumDimN
+
NumDimK
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
start
,
auto
end
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
start
+
i
];
},
Number
<
end
-
start
>
{});
};
const
auto
b_ns_ks_lengths
=
to_tuple
(
b_gs_ns_ks_lengths_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimN
+
NumDimK
>
{});
const
auto
b_ns_ks_strides
=
to_tuple
(
b_gs_ns_ks_strides_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimN
+
NumDimK
>
{});
// dimension Ids for N0, N1, ...
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimN
,
1
>::
type
{};
// dimension Ids for K0, K1, ...
constexpr
auto
kDimIds
=
typename
arithmetic_sequence_gen
<
NumDimN
,
NumDimN
+
NumDimK
,
1
>::
type
{};
// lengths for K0, K1, ...
const
auto
kLengths
=
get_container_subset
(
b_ns_ks_lengths
,
kDimIds
);
// lengths for N0, N1, ...
const
auto
nLengths
=
get_container_subset
(
b_ns_ks_lengths
,
nDimIds
);
if
constexpr
(
BSpec
==
TensorSpecialization
::
Packed
)
{
auto
N
=
container_reduce
(
nLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
K
=
container_reduce
(
kLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
const
auto
b_grid_desc_nraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
K
),
make_tuple
(
b_ns_ks_strides
[
Number
<
NumDimN
-
1
>
{}],
b_ns_ks_strides
[
Number
<
NumDimN
+
NumDimK
-
1
>
{}]));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
else
{
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const
auto
b_grid_desc_ns_ks
=
make_naive_tensor_descriptor
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const
auto
b_grid_desc_nraw_kraw
=
transform_tensor_descriptor
(
b_grid_desc_ns_ks
,
make_tuple
(
make_merge_transform
(
nLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
nDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
}
// assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_strides_vec
)
{
assert
(
e_gs_ms_ns_lengths_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimN
&&
e_gs_ms_ns_strides_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimN
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
start
,
auto
end
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
start
+
i
];
},
Number
<
end
-
start
>
{});
};
const
auto
e_ms_ns_lengths
=
to_tuple
(
e_gs_ms_ns_lengths_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimN
>
{});
const
auto
e_ms_ns_strides
=
to_tuple
(
e_gs_ms_ns_strides_vec
,
Number
<
NumDimG
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimN
>
{});
// dimension Ids for M0, M1, ...
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimM
,
1
>::
type
{};
// dimension Ids for N0, N1, ...
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
NumDimM
,
NumDimM
+
NumDimN
,
1
>::
type
{};
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
e_ms_ns_lengths
,
mDimIds
);
// lengths for K0, K1, ...
const
auto
nLengths
=
get_container_subset
(
e_ms_ns_lengths
,
nDimIds
);
if
constexpr
(
DESpec
==
TensorSpecialization
::
Packed
)
{
auto
M
=
container_reduce
(
mLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
N
=
container_reduce
(
nLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
const
auto
e_grid_desc_mraw_nraw
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
e_ms_ns_strides
[
Number
<
NumDimM
-
1
>
{}],
e_ms_ns_strides
[
Number
<
NumDimM
+
NumDimN
-
1
>
{}]));
return
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
}
else
{
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
const
auto
e_grid_desc_ms_ns
=
make_naive_tensor_descriptor
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
// transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const
auto
e_grid_desc_mraw_nraw
=
transform_tensor_descriptor
(
e_grid_desc_ms_ns
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
nLengths
)),
make_tuple
(
mDimIds
,
nDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
}
}
// assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static
auto
MakeEGridDescriptor_G_M_N
(
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_strides_vec
)
{
assert
(
e_gs_ms_ns_lengths_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimN
&&
e_gs_ms_ns_strides_vec
.
size
()
==
NumDimG
+
NumDimM
+
NumDimN
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
start
,
auto
end
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
start
+
i
];
},
Number
<
end
-
start
>
{});
};
const
auto
e_gs_ms_ns_lengths
=
to_tuple
(
e_gs_ms_ns_lengths_vec
,
Number
<
0
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimN
>
{});
const
auto
e_gs_ms_ns_strides
=
to_tuple
(
e_gs_ms_ns_strides_vec
,
Number
<
0
>
{},
Number
<
NumDimG
+
NumDimM
+
NumDimN
>
{});
// dimension Ids for G0, G1, ...
constexpr
auto
gDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimG
,
1
>::
type
{};
// dimension Ids for M0, M1, ...
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
NumDimG
,
NumDimG
+
NumDimM
,
1
>::
type
{};
// dimension Ids for N0, N1, ...
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
NumDimG
+
NumDimM
,
NumDimG
+
NumDimM
+
NumDimN
,
1
>::
type
{};
// lengths for G0, G1, ...
const
auto
gLengths
=
get_container_subset
(
e_gs_ms_ns_lengths
,
gDimIds
);
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
e_gs_ms_ns_lengths
,
mDimIds
);
// lengths for K0, K1, ...
const
auto
nLengths
=
get_container_subset
(
e_gs_ms_ns_lengths
,
nDimIds
);
if
constexpr
(
DESpec
==
TensorSpecialization
::
Packed
)
{
auto
G
=
container_reduce
(
gLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
M
=
container_reduce
(
mLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
auto
N
=
container_reduce
(
nLengths
,
math
::
multiplies
{},
Number
<
1
>
{});
const
auto
e_grid_desc_g_mraw_nraw
=
make_naive_tensor_descriptor
(
make_tuple
(
G
,
M
,
N
),
make_tuple
(
e_gs_ms_ns_strides
[
Number
<
NumDimG
-
1
>
{}],
e_gs_ms_ns_strides
[
Number
<
NumDimG
+
NumDimM
-
1
>
{}],
e_gs_ms_ns_strides
[
Number
<
NumDimG
+
NumDimM
+
NumDimN
-
1
>
{}]));
// return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
return
e_grid_desc_g_mraw_nraw
;
}
else
{
// naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const
auto
e_grid_desc_gs_ms_ns
=
make_naive_tensor_descriptor
(
e_gs_ms_ns_lengths
,
e_gs_ms_ns_strides
);
// transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const
auto
e_grid_desc_g_mraw_nraw
=
transform_tensor_descriptor
(
e_grid_desc_gs_ms_ns
,
make_tuple
(
make_merge_transform
(
gLengths
),
make_merge_transform
(
mLengths
),
make_merge_transform
(
nLengths
)),
make_tuple
(
gDimIds
,
mDimIds
,
nDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
// return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
return
e_grid_desc_g_mraw_nraw
;
}
}
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_lengths_vec
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_strides_vec
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
DeviceOp
::
MakeEGridDescriptor_M_N
(
ds_gs_ms_ns_lengths_vec
[
i
],
ds_gs_ms_ns_strides_vec
[
i
]);
},
Number
<
NumDTensor
>
{});
}
static
auto
MakeDsGridDescriptor_G_M_N
(
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_lengths_vec
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_strides_vec
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
DeviceOp
::
MakeEGridDescriptor_G_M_N
(
ds_gs_ms_ns_lengths_vec
[
i
],
ds_gs_ms_ns_strides_vec
[
i
]);
},
Number
<
NumDTensor
>
{});
}
// Gridwise descriptor, mapping to whole given provblem.
using
AGridDesc_M_K
=
decltype
(
MakeAGridDescriptor_M_K
({},
{}));
using
BGridDesc_N_K
=
decltype
(
MakeBGridDescriptor_N_K
({},
{}));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
({},
{}));
using
DsGridDesc_G_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_G_M_N
({},
{}))
>
;
using
EGridDesc_G_M_N
=
decltype
(
MakeEGridDescriptor_G_M_N
({},
{}));
struct
ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch
(
index_t
batch_stride_A
,
index_t
batch_stride_B
,
DsGridDesc_G_M_N
ds_grid_desc_g_m_n
,
EGridDesc_G_M_N
e_grid_desc_g_m_n
)
:
batch_stride_A_
(
batch_stride_A
),
batch_stride_B_
(
batch_stride_B
),
ds_grid_desc_g_m_n_
(
ds_grid_desc_g_m_n
),
e_grid_desc_g_m_n_
(
e_grid_desc_g_m_n
)
{
}
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
static_cast
<
long_index_t
>
(
g_idx
)
*
batch_stride_A_
;
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
static_cast
<
long_index_t
>
(
g_idx
)
*
batch_stride_B_
;
}
__host__
__device__
constexpr
auto
GetDsPtrOffset
(
index_t
g_idx
)
const
{
std
::
array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_offset
[
i
]
=
static_cast
<
long_index_t
>
(
g_idx
)
*
ds_grid_desc_g_m_n_
[
i
].
CalculateOffset
(
make_multi_index
(
1
,
0
,
0
));
});
return
ds_offset
;
}
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
static_cast
<
long_index_t
>
(
g_idx
)
*
e_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
1
,
0
,
0
));
}
private:
index_t
batch_stride_A_
;
index_t
batch_stride_B_
;
DsGridDesc_G_M_N
ds_grid_desc_g_m_n_
;
EGridDesc_G_M_N
e_grid_desc_g_m_n_
;
};
// A desc for source in blockwise copy
template
<
typename
AGridDesc_M_K
>
__host__
__device__
static
constexpr
auto
MakeAGridDescriptor_K0_M_K1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
)
{
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
AK0
=
K
/
K1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
K1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// B desc for source in blockwise copy
template
<
typename
BGridDesc_N_K
>
__host__
__device__
static
constexpr
auto
MakeBGridDescriptor_K0_N_K1
(
const
BGridDesc_N_K
&
b_grid_desc_n_k
)
{
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
BK0
=
K
/
K1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
using
AGridDesc_K0_M_K1
=
decltype
(
DeviceOp
::
MakeAGridDescriptor_K0_M_K1
(
AGridDesc_M_K
{}));
using
BGridDesc_K0_N_K1
=
decltype
(
DeviceOp
::
MakeBGridDescriptor_K0_N_K1
(
BGridDesc_N_K
{}));
// GridwiseOp
using
GridwiseOp
=
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
<
// DataType Family
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
// InMemory Data Descriptor
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
// ElementwiseOp Family
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
// Tiling Family
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerWMMA
,
NPerWMMA
,
K1
,
MRepeat
,
NRepeat
,
// ThreadCluster Family
BlockSize
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
NumPrefetch
,
LoopSched
,
PipelineVer
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
void
*
p_a_grid
,
const
void
*
p_b_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
void
*
p_e_grid
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_strides
,
index_t
M01
,
index_t
N01
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
a_grid_desc_m_k_
{},
b_grid_desc_n_k_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{},
ds_grid_desc_g_m_n_
{
DeviceOp
::
MakeDsGridDescriptor_G_M_N
(
ds_gs_ms_ns_lengths
,
ds_gs_ms_ns_strides
)},
e_grid_desc_g_m_n_
{
DeviceOp
::
MakeEGridDescriptor_G_M_N
(
e_gs_ms_ns_lengths
,
e_gs_ms_ns_strides
)},
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock
{},
e_grid_desc_mblock_mperblock_nblock_nperblock
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
a_mz_stride_
{},
a_kz_stride_
{},
b_nz_stride_
{},
b_kz_stride_
{},
ds_nz_stride_
{},
e_nz_stride_
{},
a_batch_stride_
{
a_gs_ms_ks_strides
[
NumDimG
-
1
]},
b_batch_stride_
{
b_gs_ns_ks_strides
[
NumDimG
-
1
]},
compute_ptr_offset_of_batch_
{
a_batch_stride_
,
b_batch_stride_
,
ds_grid_desc_g_m_n_
,
e_grid_desc_g_m_n_
}
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
// D pointer
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds_grid
[
i
]);
});
a_grid_desc_m_k_
=
DeviceOp
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
b_grid_desc_n_k_
=
DeviceOp
::
MakeBGridDescriptor_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
);
ds_grid_desc_m_n_
=
DeviceOp
::
MakeDsGridDescriptor_M_N
(
ds_gs_ms_ns_lengths
,
ds_gs_ms_ns_strides
);
e_grid_desc_m_n_
=
DeviceOp
::
MakeEGridDescriptor_M_N
(
e_gs_ms_ns_lengths
,
e_gs_ms_ns_strides
);
a_grid_desc_k0_m_k1_
=
DeviceOp
::
MakeAGridDescriptor_K0_M_K1
(
a_grid_desc_m_k_
);
b_grid_desc_k0_n_k1_
=
DeviceOp
::
MakeBGridDescriptor_K0_N_K1
(
b_grid_desc_n_k_
);
block_2_ctile_map_
=
GridwiseOp
::
MakeDefaultBlock2CTileMap
(
e_grid_desc_m_n_
,
M01
,
N01
);
ds_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseOp
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
e_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseOp
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
// for sanity check of vector memory access
a_mz_stride_
=
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
-
1
];
a_kz_stride_
=
a_gs_ms_ks_strides
[
NumDimG
+
NumDimM
+
NumDimK
-
1
];
b_nz_stride_
=
b_gs_ns_ks_strides
[
NumDimG
+
NumDimN
-
1
];
b_kz_stride_
=
b_gs_ns_ks_strides
[
NumDimG
+
NumDimN
+
NumDimK
-
1
];
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
{
ds_nz_stride_
[
i
]
=
ds_gs_ms_ns_strides
[
i
][
NumDimG
+
NumDimM
+
NumDimN
-
1
];
}
e_nz_stride_
=
e_gs_ms_ns_strides
[
NumDimG
+
NumDimM
+
NumDimN
-
1
];
}
// Pointers
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
typename
GridwiseOp
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
// Tensor Descriptors
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
DsGridDesc_G_M_N
ds_grid_desc_g_m_n_
;
EGridDesc_G_M_N
e_grid_desc_g_m_n_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
;
// Block to Tile mapping
typename
GridwiseOp
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
// Idle
index_t
M01_
;
index_t
N01_
;
// ElementwiseOp
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// Strides for the last M/N/K dimensions of A/B/Ds/E
// for sanity check of vector load/store
index_t
a_mz_stride_
;
index_t
a_kz_stride_
;
index_t
b_nz_stride_
;
index_t
b_kz_stride_
;
std
::
array
<
index_t
,
NumDTensor
>
ds_nz_stride_
;
index_t
e_mz_stride_
;
index_t
e_nz_stride_
;
index_t
a_batch_stride_
;
index_t
b_batch_stride_
;
// Batch Offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
index_t
G
=
arg
.
e_grid_desc_g_m_n_
.
GetLength
(
I0
);
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
)
*
G
;
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_contraction_multiple_d_wmma_cshuffle
<
GridwiseOp
,
ADataType
,
BDataType
,
typename
GridwiseOp
::
DsGridPointer
,
EDataType
,
DeviceOp
::
AGridDesc_K0_M_K1
,
DeviceOp
::
BGridDesc_K0_N_K1
,
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
ComputePtrOffsetOfStridedBatch
,
typename
GridwiseOp
::
DefaultBlock2CTileMap
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
G
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
block_2_ctile_map_
);
};
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
get_device_name
()
==
"gfx1100"
||
ck
::
get_device_name
()
==
"gfx1101"
||
ck
::
get_device_name
()
==
"gfx1102"
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
return
false
;
}
}
else
{
return
false
;
}
if
(
!
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
{
return
false
;
}
// check vector access
static_assert
((
ABlockTransferSrcVectorDim
==
1
||
ABlockTransferSrcVectorDim
==
2
)
&&
(
BBlockTransferSrcVectorDim
==
1
||
BBlockTransferSrcVectorDim
==
2
),
"wrong!"
);
// vector memory access of A: could be on M or AK1 dimension
if
constexpr
(
ABlockTransferSrcVectorDim
==
1
)
{
if
(
!
(
arg
.
a_mz_stride_
==
1
&&
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
if
(
!
(
arg
.
a_kz_stride_
==
1
&&
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
// vector memory access of B: could be on N or BK1 dimension
if
constexpr
(
BBlockTransferSrcVectorDim
==
1
)
{
if
(
!
(
arg
.
b_nz_stride_
==
1
&&
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
if
(
!
(
arg
.
b_kz_stride_
==
1
&&
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
// vector memory access of Ds: always on NPerBlock dimension
bool
valid_d_access
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
if
(
!
(
arg
.
ds_nz_stride_
[
i
]
==
1
&&
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
i
].
GetLength
(
I3
)
%
CDEShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
valid_d_access
=
false
;
}
});
if
(
valid_d_access
==
false
)
{
return
false
;
}
// vector memory access of E: always on NPerBlock dimension
if
(
!
((
arg
.
e_nz_stride_
==
1
&&
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I3
)
%
CDEShuffleBlockTransferScalarPerVector_NPerBlock
==
0
)
||
CDEShuffleBlockTransferScalarPerVector_NPerBlock
==
1
))
{
return
false
;
}
return
true
;
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
void
*
p_a
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_strides
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_e
,
a_gs_ms_ks_lengths
,
b_gs_ns_ks_lengths
,
ds_gs_ms_ns_lengths
,
e_gs_ms_ns_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_strides
,
ds_gs_ms_ns_strides
,
e_gs_ms_ns_strides
,
1
,
1
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
e_gs_ms_ns_strides
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_ds
,
p_e
,
a_gs_ms_ks_lengths
,
b_gs_ns_ks_lengths
,
ds_gs_ms_ns_lengths
,
e_gs_ms_ns_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_strides
,
ds_gs_ms_ns_strides
,
e_gs_ms_ns_strides
,
1
,
1
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
// clang-format off
str
<<
"DeviceBatchedContractionMultipleD_Wmma_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
", "
<<
K1
<<
", "
<<
MPerWMMA
<<
", "
<<
NPerWMMA
<<
", "
<<
MRepeat
<<
", "
<<
NRepeat
<<
">"
<<
" NumPrefetch: "
<<
NumPrefetch
<<
", "
<<
"LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
View file @
e0041ad8
...
...
@@ -56,7 +56,8 @@ __global__ void
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
...
...
@@ -839,7 +840,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
View file @
e0041ad8
...
...
@@ -74,7 +74,8 @@ __global__ void
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
View file @
e0041ad8
...
...
@@ -60,7 +60,8 @@ __global__ void
const
index_t
batch_count
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
...
@@ -588,7 +589,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
View file @
e0041ad8
...
...
@@ -83,7 +83,8 @@ __global__ void
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
...
...
@@ -579,7 +580,8 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
View file @
e0041ad8
...
...
@@ -68,7 +68,8 @@ __global__ void
const
index_t
batch_count
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
...
@@ -579,6 +580,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
BatchStrideD1s
,
BatchStrideE1
}
{
#if DEBUG_LOG
std
::
cout
<<
"a0_grid_desc_m_k_{"
<<
a0_grid_desc_m_k_
.
GetLength
(
I0
)
<<
", "
<<
a0_grid_desc_m_k_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"b0_grid_desc_n_k_{"
<<
b0_grid_desc_n_k_
.
GetLength
(
I0
)
<<
", "
...
...
@@ -601,6 +603,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<<
std
::
endl
;
std
::
cout
<<
"e1_grid_desc_m_n_{"
<<
e1_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
e1_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
#endif
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
using
D0Layout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
D0sLayout
>>
;
...
...
@@ -786,9 +789,44 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return
true
;
}
// check if DsLayout is supported
template
<
typename
RefLayout
,
typename
DsLayout
,
const
index_t
NumDTensor
>
static
bool
CheckDLayout
()
{
static
bool
valid
=
true
;
// iterate over DLayout tuple
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
// if RefLayout and DLayout are same, keep valid true, otherwise false
valid
=
valid
&&
is_same_v
<
RefLayout
,
DLayout
>
;
});
return
valid
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
))
{
return
false
;
}
// Check supported layouts
// A0 - Row
// B0 - Col
// D0s - Rows
// B1 - Row or Col
// D1s - Rows
// E1 - Row
if
(
!
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
A0Layout
>
&&
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
B0Layout
>
&&
CheckDLayout
<
tensor_layout
::
gemm
::
RowMajor
,
D0sLayout
,
NumD0Tensor
>
()
&&
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
B1Layout
>
||
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
B1Layout
>
)
&&
CheckDLayout
<
tensor_layout
::
gemm
::
RowMajor
,
D1sLayout
,
NumD1Tensor
>
()
&&
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
E1Layout
>
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
View file @
e0041ad8
...
...
@@ -59,7 +59,8 @@ __global__ void
const
ComputeBasePrtOfBatch
compute_base_ptr_of_batch_
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
...
...
@@ -657,7 +658,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if
0
#if
DEBUG_LOG
{
std
::
cout
<<
"arg.Batch_ = "
<<
arg
.
Batch_
<<
std
::
endl
;
...
...
@@ -674,8 +675,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0)
<< "}"
<< std::endl;
std
::
cout
<<
"arg.reduce_grid_desc_m_{ "
<<
arg
.
reduce_grid_desc_m_
.
GetLength
(
I0
)
<<
"}"
<<
std
::
endl
;
}
#endif
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
e0041ad8
...
...
@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_
multiple_d_
softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -25,15 +25,17 @@ namespace device {
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
D0sPointer
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
Acc
ElementwiseOperation
,
typename
C0DE
ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C
1DE
ElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
...
...
@@ -47,22 +49,26 @@ __global__ void
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b1_grid
,
FloatC
*
__restrict__
p_c_grid
,
D0sPointer
p_d0s_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
Acc
ElementwiseOperation
acc
_element_op
,
const
C0DE
ElementwiseOperation
c0de
_element_op
,
const
B1ElementwiseOperation
b1_element_op
,
const
CElementwiseOperation
c_element_op
,
const
C
1DE
ElementwiseOperation
c
1de
_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
const
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
...
@@ -77,20 +83,28 @@ __global__ void
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetCBasePtr
(
g_idx
)));
static_for
<
0
,
p_d0s_grid
.
Size
(),
1
>
{}([
&
](
auto
In
)
{
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetD0BasePtr
(
g_idx
,
In
)));
p_d0s_grid
(
In
)
=
p_d0s_grid
(
In
)
+
d0_batch_offset
;
});
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_d0s_grid
,
p_shared
,
a_element_op
,
b_element_op
,
acc
_element_op
,
c0de
_element_op
,
b1_element_op
,
c_element_op
,
c
1de
_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
block_2_ctile_map
,
c0_matrix_mask
);
#else
...
...
@@ -98,15 +112,17 @@ __global__ void
ignore
=
p_b_grid
;
ignore
=
p_b1_grid
;
ignore
=
p_c_grid
;
ignore
=
p_d0s_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
acc
_element_op
;
ignore
=
c0de
_element_op
;
ignore
=
b1_element_op
;
ignore
=
c_element_op
;
ignore
=
c
1de
_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
c1_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
;
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
ignore
=
compute_base_ptr_of_batch
;
...
...
@@ -126,15 +142,15 @@ template <index_t NumDimG,
typename
BDataType
,
typename
B1DataType
,
typename
CDataType
,
typename
Acc0Bia
sDataType
,
typename
Acc1Bia
sDataType
,
typename
D0
sDataType
,
typename
D1
sDataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
Acc
ElementwiseOperation
,
typename
C0DE
ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C
1DE
ElementwiseOperation
,
GemmSpecialization
GemmSpec
,
TensorSpecialization
ASpec
,
TensorSpecialization
BSpec
,
...
...
@@ -192,23 +208,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
BDataType
,
B1DataType
,
CDataType
,
Acc0Bia
sDataType
,
Acc1Bia
sDataType
,
D0
sDataType
,
D1
sDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
Acc
ElementwiseOperation
,
C0DE
ElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
C
1DE
ElementwiseOperation
,
MaskingSpec
>
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
static
constexpr
index_t
Num
Acc0Bias
=
Acc0Bia
sDataType
::
Size
();
static
constexpr
index_t
Num
Acc1Bias
=
Acc1Bia
sDataType
::
Size
();
static
constexpr
index_t
Num
D0Tensor
=
D0
sDataType
::
Size
();
static
constexpr
index_t
Num
D1Tensor
=
D1
sDataType
::
Size
();
// TODO ANT: implement bias combination
static_assert
(
Num
Acc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
static_assert
(
Num
D1Tensor
==
0
,
"
Gemm1
Bias addition is unimplemented"
);
#if 0
// TODO ANT: use alias
...
...
@@ -261,14 +277,40 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Number
<
B1K1
>
{});
}
static
auto
MakeD0sGridDescriptor_M_N
(
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>&
acc0_biases_gs_ms_ns_strides
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
acc0_biases_gs_ms_ns_lengths
[
i
],
acc0_biases_gs_ms_ns_strides
[
i
]);
},
Number
<
NumD0Tensor
>
{});
}
static
auto
MakeD0sGridDescriptor_G_M_N
(
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>&
acc0_biases_gs_ms_ns_strides
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_biases_gs_ms_ns_lengths
[
i
],
acc0_biases_gs_ms_ns_strides
[
i
]);
},
Number
<
NumD0Tensor
>
{});
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeB1GridDescriptor_BK0_N_BK1
({},
{}));
using
CGridDesc_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
C
1
GridDesc_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
AGridDesc_G_M_K
=
decltype
(
Transform
::
MakeAGridDescriptor_G_M_K
({},
{}));
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
C1GridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
D0sGridDesc_M_N
=
decltype
(
MakeD0sGridDescriptor_M_N
({},
{}));
using
D0sGridDesc_G_M_N
=
decltype
(
MakeD0sGridDescriptor_G_M_N
({},
{}));
constexpr
static
auto
make_MaskOutPredicate
()
{
...
...
@@ -288,11 +330,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
const
BGridDesc_G_N_K
&
b_grid_desc_g_n_k
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
)
const
C1GridDesc_G_M_N
&
c1_grid_desc_g_m_n
,
const
D0sGridDesc_G_M_N
&
d0s_grid_desc_g_m_n
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
)
c1_grid_desc_g_m_n_
(
c1_grid_desc_g_m_n
),
d0s_grid_desc_g_m_n_
(
d0s_grid_desc_g_m_n
)
{
}
...
...
@@ -313,32 +357,42 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
__host__
__device__
constexpr
long_index_t
GetCBasePtr
(
index_t
g_idx
)
const
{
return
c_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
return
c1_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
template
<
index_t
I
>
__host__
__device__
constexpr
long_index_t
GetD0BasePtr
(
index_t
g_idx
,
Number
<
I
>
d0_idx
)
const
{
return
d0s_grid_desc_g_m_n_
[
d0_idx
].
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
C1GridDesc_G_M_N
c1_grid_desc_g_m_n_
;
D0sGridDesc_G_M_N
d0s_grid_desc_g_m_n_
;
};
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
using
GridwiseGemm
=
GridwiseBatchedGemm
MultipleD
SoftmaxGemm_Xdl_CShuffle
<
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
CShuffleDataType
,
CDataType
,
D0sDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
Acc
ElementwiseOperation
,
C0DE
ElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
C
1DE
ElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
B1GridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
C1GridDesc_M_N
,
D0sGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -395,8 +449,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const
BDataType
*
p_b_grid
,
const
B1DataType
*
p_b1_grid
,
CDataType
*
p_c_grid
,
const
std
::
array
<
void
*
,
Num
Acc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
Num
Acc1Bias
>
p_acc1_biases
,
const
std
::
array
<
void
*
,
Num
D0Tensor
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
Num
D1Tensor
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -405,44 +459,48 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc1Bias
>
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D0Tensor
>&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D0Tensor
>&
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D1Tensor
>&
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc1Bias
>
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D1Tensor
>&
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
Acc
ElementwiseOperation
acc
_element_op
,
C0DE
ElementwiseOperation
c0de
_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
C
1DE
ElementwiseOperation
c
1de
_element_op
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
p_d0s_grid_
{},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
)},
b1_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeB1GridDescriptor_BK0_N_BK1
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
c_grid_desc_m_n_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
c
1
_grid_desc_m_n_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
a_grid_desc_g_m_k_
{
Transform
::
MakeAGridDescriptor_G_M_K
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b_grid_desc_g_n_k_
{
Transform
::
MakeB0GridDescriptor_G_N_K
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
)},
b1_grid_desc_g_n_k_
{
Transform
::
MakeB1GridDescriptor_G_N_K
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
c1_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
d0s_grid_desc_g_m_n_
{
DeviceOp
::
MakeD0sGridDescriptor_G_M_N
(
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
)},
c1_grid_desc_mblock_mperblock_nblock_nperblock_
{},
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c1_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
acc
_element_op_
{
acc
_element_op
},
c0de
_element_op_
{
c0de
_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c
1de
_element_op_
{
c
1de
_element_op
},
c0_matrix_mask_
{
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
raw_lengths_mz_nz_kz_gemm1nz_
{
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
...
...
@@ -456,27 +514,39 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_strides
[
NumDimG
+
NumDimO
+
NumDimN
-
1
]},
c_mz_gemm1nz_strides_
{
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
compute_base_ptr_of_batch_
{
a_grid_desc_g_m_k_
,
b_grid_desc_g_n_k_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
}
batch_count_
{
c1_grid_desc_g_m_n_
.
GetLength
(
I0
)},
compute_base_ptr_of_batch_
{
a_grid_desc_g_m_k_
,
b_grid_desc_g_n_k_
,
b1_grid_desc_g_n_k_
,
c1_grid_desc_g_m_n_
,
d0s_grid_desc_g_m_n_
}
{
// TODO ANT: implement bias addition
ignore
=
p_acc0_biases
;
ignore
=
p_acc1_biases
;
ignore
=
acc0_biases_gs_ms_ns_lengths
;
ignore
=
acc0_biases_gs_ms_ns_strides
;
ignore
=
acc1_biases_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_biases_gs_ms_gemm1ns_strides
;
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
using
D0DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
D0sDataType
>>
;
// D0 pointer
p_d0s_grid_
(
i
)
=
static_cast
<
const
D0DataType
*>
(
p_acc0_biases
[
i
]);
});
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
b1_grid_desc_bk0_n_bk1_
,
c_grid_desc_m_n_
,
c
1
_grid_desc_m_n_
,
block_2_ctile_map_
))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
c1_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c1_grid_desc_m_n_
);
D0sGridDesc_M_N
d0s_grid_desc_m_n
{
DeviceOp
::
MakeD0sGridDescriptor_M_N
(
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
)};
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
=
GridwiseGemm
::
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
d0s_grid_desc_m_n
);
}
}
...
...
@@ -485,19 +555,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std
::
cout
<<
"a_grid_desc_g_m_k_: "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I0
)
<<
", "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I1
)
<<
", "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I2
)
<<
'\n'
;
// a_grid_desc_g_m_k_.Print();
std
::
cout
<<
"b_grid_desc_g_n_k_: "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b_grid_desc_g_n_k_.Print();
std
::
cout
<<
"b1_grid_desc_g_n_k_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b1_grid_desc_g_n_k_.Print();
std
::
cout
<<
"c_grid_desc_g_m_n_: "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I0
)
<<
", "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
// c_grid_desc_g_m_n_.Print();
std
::
cout
<<
"c1_grid_desc_g_m_n_: "
<<
c1_grid_desc_g_m_n_
.
GetLength
(
I0
)
<<
", "
<<
c1_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
<<
c1_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
}
// pointers
...
...
@@ -505,18 +571,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const
BDataType
*
p_b_grid_
;
const
B1DataType
*
p_b1_grid_
;
CDataType
*
p_c_grid_
;
typename
GridwiseGemm
::
D0sGridPointer
p_d0s_grid_
;
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
C
1
GridDesc_M_N
c
1
_grid_desc_m_n_
;
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
C1GridDesc_G_M_N
c1_grid_desc_g_m_n_
;
D0sGridDesc_G_M_N
d0s_grid_desc_g_m_n_
;
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
// block-to-c-tile map
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
...
@@ -524,9 +595,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// element-wise op
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
Acc
ElementwiseOperation
acc
_element_op_
;
C0DE
ElementwiseOperation
c0de
_element_op_
;
B1ElementwiseOperation
b1_element_op_
;
CElementwiseOperation
c_element_op_
;
C
1DE
ElementwiseOperation
c
1de
_element_op_
;
// check C0 masking and padding
C0MatrixMask
c0_matrix_mask_
;
...
...
@@ -555,7 +626,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
batch_count_
;
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c
1
_grid_desc_m_n_
)
*
arg
.
batch_count_
;
// Gemm0_K
const
auto
K
=
...
...
@@ -568,15 +639,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
typename
GridwiseGemm
::
D0sGridPointer
,
AElementwiseOperation
,
BElementwiseOperation
,
Acc
ElementwiseOperation
,
C0DE
ElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
C
1DE
ElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
C0MatrixMask
,
...
...
@@ -591,15 +664,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg
.
p_b_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
arg
.
p_d0s_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
acc
_element_op_
,
arg
.
c0de
_element_op_
,
arg
.
b1_element_op_
,
arg
.
c_element_op_
,
arg
.
c
1de
_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c1_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
compute_base_ptr_of_batch_
,
...
...
@@ -636,11 +711,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
#if
0
#if
DEBUG_LOG
arg
.
Print
();
#endif
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
))
{
return
false
;
}
...
...
@@ -648,9 +724,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// TODO ANT: Check if tensor specialization & strides mismatch
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
c_m
=
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
);
const
index_t
c_g
=
arg
.
c
1
_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
c_m
=
arg
.
c
1
_grid_desc_m_n_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
arg
.
c
1
_grid_desc_m_n_
.
GetLength
(
I1
);
const
index_t
a_m
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
...
...
@@ -700,7 +776,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c
1
_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
);
}
...
...
@@ -715,8 +791,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const
BDataType
*
p_b
,
const
B1DataType
*
p_b1
,
CDataType
*
p_c
,
const
std
::
array
<
void
*
,
Num
Acc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
Num
Acc1Bias
>
p_acc1_biases
,
const
std
::
array
<
void
*
,
Num
D0Tensor
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
Num
D1Tensor
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -725,17 +801,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc1Bias
>
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D0Tensor
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D0Tensor
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D1Tensor
>
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc1Bias
>
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D1Tensor
>
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
Acc
ElementwiseOperation
acc
_element_op
,
C0DE
ElementwiseOperation
c0de
_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
C
1DE
ElementwiseOperation
c
1de
_element_op
)
{
return
Argument
{
p_a
,
p_b
,
...
...
@@ -757,9 +833,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
a_element_op
,
b_element_op
,
acc
_element_op
,
c0de
_element_op
,
b1_element_op
,
c_element_op
};
c
1de
_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -771,8 +847,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const
void
*
p_b
,
const
void
*
p_b1
,
void
*
p_c
,
const
std
::
array
<
void
*
,
Num
Acc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
Num
Acc1Bias
>
p_acc1_biases
,
const
std
::
array
<
void
*
,
Num
D0Tensor
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
Num
D1Tensor
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -781,17 +857,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc1Bias
>
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D0Tensor
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D0Tensor
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D1Tensor
>
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc1Bias
>
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D1Tensor
>
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
Acc
ElementwiseOperation
acc
_element_op
,
C0DE
ElementwiseOperation
c0de
_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
override
C
1DE
ElementwiseOperation
c
1de
_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
...
...
@@ -813,9 +889,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
acc1_biases_gs_ms_gemm1ns_strides
,
a_element_op
,
b_element_op
,
acc
_element_op
,
c0de
_element_op
,
b1_element_op
,
c_element_op
);
c
1de
_element_op
);
}
// polymorphic
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
View file @
e0041ad8
...
...
@@ -62,7 +62,8 @@ __global__ void
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
...
@@ -612,7 +613,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
View file @
e0041ad8
...
...
@@ -75,7 +75,8 @@ __global__ void
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
...
...
@@ -412,7 +413,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if
0
#if
DEBUG_LOG
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
e0041ad8
...
...
@@ -52,7 +52,8 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
@@ -581,7 +582,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
))
{
return
false
;
}
if
(
ck
::
get_device_name
()
!=
"gfx90a"
&&
std
::
is_same
<
ADataType
,
double
>::
value
)
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
e0041ad8
...
...
@@ -488,7 +488,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
{
using
Argument
=
DeviceOp
::
Argument
;
void
ShowInfo
(
const
Argument
&
arg
)
void
Print
(
const
Argument
&
arg
)
{
std
::
cout
<<
"arg.a_grid_desc_kbatch_k0_m_k1_{"
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
...
...
@@ -508,7 +508,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
ShowInfo
(
arg
);
if
(
stream_config
.
log_level_
>
0
)
{
Print
(
arg
);
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
...
...
@@ -774,7 +777,19 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
K0PerBlock
<<
", "
<<
K1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
ABlockTransferDstScalarPerVector_K1
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferDstScalarPerVector_K1
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
CBlockTransferScalarPerVector_NWaveNPerXdl
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
View file @
e0041ad8
...
...
@@ -549,7 +549,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float
ave_time
=
0
;
for
(
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_k0_m_k1_container_
.
size
();
i
++
)
{
#if
0
#if
DEBUG_LOG
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_container_{"
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
)
<<
", "
...
...
@@ -822,7 +822,14 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
K0PerBlock
<<
", "
<<
K1
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
ABlockTransferDstScalarPerVector_K1
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferDstScalarPerVector_K1
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
View file @
e0041ad8
...
...
@@ -644,7 +644,7 @@ struct
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if
0
#if
DEBUG_LOG
{
std
::
cout
<<
DeviceOp
{}.
GetTypeString
()
<<
std
::
endl
;
std
::
cout
<<
"N "
<<
arg
.
Conv_N_
<<
", "
...
...
@@ -956,7 +956,19 @@ struct
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
K0PerBlock
<<
", "
<<
K1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
ABlockTransferDstScalarPerVector_K1
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferDstScalarPerVector_K1
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
CBlockTransferScalarPerVector_NWaveNPerXdl
<<
">"
;
// clang-format on
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment