gaoqiong / composable_kernel · Commits · ed305f6b

Commit ed305f6b, authored Sep 28, 2023 by Umang Yadav

    formatting

parent 9f4e3544

Changes: 45 files. Showing 20 changed files with 311 additions and 304 deletions (+311 −304). The commit is formatting-only (whitespace, line wrapping, and trailing-comment placement); each file section below lists the affected hunks with the code as it reads after the commit.
include/ck/ck.hpp (+4 −4)
include/ck/tensor_description/tensor_adaptor.hpp (+14 −10)
include/ck/tensor_description/tensor_space_filling_curve.hpp (+4 −2)
include/ck/tensor_operation/gpu/device/device_base.hpp (+7 −8)
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp (+18 −18)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp (+14 −14)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp (+19 −18)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp (+17 −17)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp (+16 −16)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp (+25 −25)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp (+19 −19)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp (+23 −23)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp (+93 −92)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp (+2 −2)
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp (+16 −16)
include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp (+1 −1)
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp (+1 −1)
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp (+1 −1)
include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp (+16 −16)
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp (+1 −1)
include/ck/ck.hpp

@@ -50,9 +50,9 @@
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
    defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
    defined(__gfx942__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000

@@ -86,7 +86,7 @@
#endif
// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_USE_AMD_WMMA

@@ -107,7 +107,7 @@
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
    defined(__gfx942__) // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif
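For context, these guards follow the standard HIP split-compilation pattern: on the host pass no __gfx*__ architecture macro is defined, so the #ifndef __HIP_DEVICE_COMPILE__ branch supplies a placeholder, while each device pass defines exactly one architecture macro and a chain of #elif branches picks the per-architecture value. A minimal self-contained sketch of the same pattern (MY_BUFFER_DWORD is a hypothetical macro used only for illustration, not part of CK):

// Hypothetical illustration of the arch-gating pattern used in ck.hpp above.
// Host compilation defines no __gfx*__ macro; each GPU code object defines one.
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define MY_BUFFER_DWORD -1
#elif defined(__gfx90a__) || defined(__gfx942__) // for GPU code
#define MY_BUFFER_DWORD 0x00020000
#else // for GPU code
#define MY_BUFFER_DWORD 0
#endif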
include/ck/tensor_description/tensor_adaptor.hpp

@@ -108,13 +108,13 @@ struct TensorAdaptor
    __host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
    {
        constexpr auto all_low_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            LowerDimensionHiddenIdss{});

        constexpr auto all_up_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            UpperDimensionHiddenIdss{});

        constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);

@@ -338,7 +338,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
            TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran];

        // sequence in, sequence out
        constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
            auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);

            // shift hidden id so every dim id is unique

@@ -360,7 +361,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
            });

            return low_dim_hidden_ids_1_mod_;
        }();

        return generate_sequence_v2(
            [&](auto i) constexpr { return Number<low_dim_hidden_ids_1_mod[i]>{}; },

@@ -382,7 +384,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
            TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran];

        // sequence in, constexpr tuple out
        constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
            auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);

            // shift hidden id

@@ -391,7 +394,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
            });

            return up_dim_hidden_ids_1_mod_;
        }();

        // constexpr tuple to sequence
        return generate_sequence_v2(
include/ck/tensor_description/tensor_space_filling_curve.hpp

@@ -94,8 +94,10 @@ struct SpaceFillingCurve
        // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
        // idim-th element of multidimensional index.
        // All constexpr variables have to be captured by VALUE.
        constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
            constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
                auto res = idx_1d.value;
                auto id  = 0;
include/ck/tensor_operation/gpu/device/device_base.hpp

@@ -14,8 +14,8 @@ namespace device {
struct BaseArgument
{
    BaseArgument()                    = default;
    BaseArgument(const BaseArgument&) = default;
    BaseArgument& operator=(const BaseArgument&) = default;

    virtual ~BaseArgument() {}

@@ -26,8 +26,8 @@ struct BaseArgument
#ifndef __HIPCC_RTC__
struct BaseInvoker
{
    BaseInvoker()                   = default;
    BaseInvoker(const BaseInvoker&) = default;
    BaseInvoker& operator=(const BaseInvoker&) = default;

    virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})

@@ -41,13 +41,12 @@ struct BaseInvoker
struct BaseOperator
{
    BaseOperator()                    = default;
    BaseOperator(const BaseOperator&) = default;
    BaseOperator& operator=(const BaseOperator&) = default;

    virtual bool IsSupportedArgument(const BaseArgument*) { return false; }

#ifndef __HIPCC_RTC__
    virtual std::string GetTypeString() const { return ""; }

@@ -66,7 +65,7 @@ struct BaseOperator
    virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
    {
        // assert(p_arg);
        p_arg->p_workspace_ = p_workspace;
    }
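Taken together, these three interfaces form the type-erased host API of a CK device operation: a BaseArgument carries the problem description plus an optional workspace pointer, and a BaseInvoker runs it. A usage sketch under that assumption follows; ConcreteOp, MakeArgumentPointer and MakeInvokerPointer are not part of this diff and stand in for a concrete device op's factory methods:

// Sketch only: exercising the Base* interfaces shown above.
// ConcreteOp / MakeArgumentPointer / MakeInvokerPointer are assumed names.
ConcreteOp op;
auto arg_ptr     = op.MakeArgumentPointer(/* buffers, lengths, strides, elementwise ops */);
auto invoker_ptr = op.MakeInvokerPointer();
void* p_workspace = nullptr; // would come from a real device allocation

if(op.IsSupportedArgument(arg_ptr.get())) // virtual, defaults to false in BaseOperator
{
    op.SetWorkSpacePointer(arg_ptr.get(), p_workspace); // stores into p_workspace_
    // Run returns the averaged kernel time when timing is enabled in StreamConfig.
    float ave_time = invoker_ptr->Run(arg_ptr.get(), StreamConfig{});
    (void)ave_time;
}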
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp

@@ -38,25 +38,25 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_contraction_multiple_d_xdl_cshuffle(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatDsPointer p_ds_grid,
            FloatE* __restrict__ p_e_grid,
            const index_t batch_count,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
            const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
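The #if CK_USE_LAUNCH_BOUNDS / __launch_bounds__ prologue recurs on every kernel touched by this commit. As a standalone illustration, the snippet below applies the same occupancy hint to a trivial HIP kernel; the macro values and the kernel body are placeholders, not CK code:

#include <hip/hip_runtime.h>

// Illustrative values; CK defines its own elsewhere.
#define CK_USE_LAUNCH_BOUNDS 1
#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MIN_BLOCK_PER_CU 1

__global__ void
#if CK_USE_LAUNCH_BOUNDS
    // Bounds the block size and requests a minimum number of resident
    // blocks per compute unit, constraining register allocation.
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        scale_kernel(float* p, float s, int n)
{
    // Placeholder body; the CK kernels above instead forward to GridwiseGemm::Run.
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        p[i] *= s;
}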
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp

@@ -60,21 +60,21 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_e_permute_xdl(
            const ABDataType* __restrict__ p_a_grid,
            const ABDataType* __restrict__ p_b_grid,
            EDataType* __restrict__ p_e_grid,
            const index_t batch_count,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
            const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp

@@ -41,25 +41,26 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_gemm_xdl_cshuffle_v1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            const FloatAB* __restrict__ p_b1_grid,
            FloatC* __restrict__ p_c_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const AccElementwiseOperation acc_element_op,
            const B1ElementwiseOperation b1_element_op,
            const CElementwiseOperation c_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp

@@ -63,24 +63,24 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_xdl(
            const ABDataType* __restrict__ p_a_grid,
            const ABDataType* __restrict__ p_b_grid,
            DsPointer p_ds_grid,
            EDataType* __restrict__ p_e_grid,
            const index_t batch_count,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock_,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
            const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp

@@ -52,23 +52,23 @@ template <typename GridwiseGemm,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dl_multiple_d(
            const ABDataType* __restrict__ p_a_grid,
            const ABDataType* __restrict__ p_b_grid,
            DsPointer p_ds_grid,
            EDataType* __restrict__ p_e_grid,
            const index_t batch_count,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
            const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
            const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
            const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
            const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
    defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp

@@ -41,32 +41,32 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_gemm_xdl_cshuffle_v1(
            const A0B0B1DataType* __restrict__ p_a0_grid,
            const A0B0B1DataType* __restrict__ p_b0_grid,
            D0sPointer p_d0s_grid,
            const A0B0B1DataType* __restrict__ p_b1_grid,
            D1sPointer p_d1s_grid,
            E1DataType* __restrict__ p_e1_grid,
            const A0ElementwiseOperation a0_element_op,
            const B0ElementwiseOperation b0_element_op,
            const CDE0ElementwiseOperation cde0_element_op,
            const B1ElementwiseOperation b1_element_op,
            const CDE1ElementwiseOperation cde1_element_op,
            const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
            const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
            const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
                d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                d1s_grid_desc_mblock_mperblock_nblock_nperblock,
            const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                e1_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2E1TileMap block_2_e1tile_map,
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp

@@ -38,26 +38,26 @@ template <typename GridwiseGemm,
          bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_reduce_xdl_cshuffle_v1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            ReducePtrsGlobal p_reduces_grid,
            const index_t batch_count,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const ReduceInElementwiseOperations reduce_in_element_ops,
            const ReduceAccElementwiseOperations reduce_out_element_ops,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
            const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
            const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp

@@ -42,30 +42,30 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            const FloatAB* __restrict__ p_b1_grid,
            FloatC* __restrict__ p_c_grid,
            D0sPointer p_d0s_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const C0DEElementwiseOperation c0de_element_op,
            const B1ElementwiseOperation b1_element_op,
            const C1DEElementwiseOperation c1de_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c1_grid_desc_mblock_mperblock_nblock_nperblock,
            const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
                d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
            const C0MatrixMask c0_matrix_mask)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp

@@ -40,27 +40,27 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            const FloatAB* __restrict__ p_b1_grid,
            FloatC* __restrict__ p_c_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const AccElementwiseOperation acc_element_op,
            const B1ElementwiseOperation b1_element_op,
            const CElementwiseOperation c_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
            const C0MatrixMask c0_matrix_mask)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))

@@ -611,7 +611,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        return true;
    }

    static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
    {
        // check vector load/store
        using Row = ck::tensor_layout::gemm::RowMajor;

@@ -842,7 +843,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
    template <class ADesc, class BDesc, class B1Desc, class CDesc>
    struct Descriptor
    {
        template <class AGridDescriptor>
        static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
        {
            const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);

@@ -852,14 +853,15 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
            const auto AK0 = K / AK1;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class BGridDescriptor>
        static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
        {
            const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);

@@ -869,14 +871,15 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
            const auto BK0 = K / BK1;

            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class B1GridDescriptor>
        static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
        {
            const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);

@@ -889,26 +892,24 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
            return transform_tensor_descriptor(
                b1_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class CGridDescriptor>
        static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
        {
            return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
        }

        using AGridDesc_AK0_M_AK1 =
            remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
        using BGridDesc_BK0_N_BK1 =
            remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
        using B1GridDesc_BK0_N_BK1 =
            remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
        using CGridDesc_M_N =
            remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;

        // GridwiseGemm
        using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<

@@ -979,8 +980,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        CGridDesc_M_N c_grid_desc_m_n;
        C0MatrixMask c0_matrix_mask;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_descriptor_mblock_mperblock_nblock_nperblock;

        // element-wise op
        AElementwiseOperation a_element_op;
        BElementwiseOperation b_element_op;

@@ -1002,10 +1004,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
              b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
              b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
              c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
              block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
              c_grid_descriptor_mblock_mperblock_nblock_nperblock{
                  GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                      c_grid_desc_m_n)},
              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
              c0_matrix_mask{c.GetLength(I1)},

@@ -1013,23 +1015,20 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
              b_element_op{b_element_op_},
              b1_element_op{b1_element_op_},
              c_element_op{c_element_op_},
              is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   b1_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_m_n,
                                                   block_2_ctile_map) and
                       IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
                                   b_grid_desc_bk0_n_bk1.GetLength(I1),
                                   a_grid_desc_ak0_m_ak1.GetLength(I0) *
                                       a_grid_desc_ak0_m_ak1.GetLength(I2),
                                   b1_grid_desc_bk0_n_bk1.GetLength(I1))}
        {
        }

        constexpr bool IsValid() const { return is_valid; }
    };

    template <class ADesc, class BDesc, class B1Desc, class CDesc>

@@ -1038,10 +1037,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                          BDesc b,
                                          B1Desc b1,
                                          CDesc c,
                                          AElementwiseOperation a_element_op = AElementwiseOperation{},
                                          BElementwiseOperation b_element_op = BElementwiseOperation{},
                                          B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
                                          CElementwiseOperation c_element_op = CElementwiseOperation{})
    {
        return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
            a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);

@@ -1061,41 +1060,43 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        if(desc.has_main_k_block_loop)
        {
            Desc::GridwiseGemm::template Run<true>(
                p_a_grid,
                p_b_grid,
                p_b1_grid,
                p_c_grid,
                p_shared_block,
                desc.a_element_op,
                desc.b_element_op,
                acc_element_op,
                desc.b1_element_op,
                desc.c_element_op,
                desc.a_grid_desc_ak0_m_ak1,
                desc.b_grid_desc_bk0_n_bk1,
                desc.b1_grid_desc_bk0_n_bk1,
                desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
                desc.block_2_ctile_map,
                desc.c0_matrix_mask);
        }
        else
        {
            Desc::GridwiseGemm::template Run<false>(
                p_a_grid,
                p_b_grid,
                p_b1_grid,
                p_c_grid,
                p_shared_block,
                desc.a_element_op,
                desc.b_element_op,
                acc_element_op,
                desc.b1_element_op,
                desc.c_element_op,
                desc.a_grid_desc_ak0_m_ak1,
                desc.b_grid_desc_bk0_n_bk1,
                desc.b1_grid_desc_bk0_n_bk1,
                desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
                desc.block_2_ctile_map,
                desc.c0_matrix_mask);
        }
    }
};
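A note on the Make*GridDescriptor_* helpers in the hunks above: each pads a 2-D (M, K) or (N, K) descriptor and then unmerges K into (K0, K1) while passing the other dimension through, producing the 3-D (K0, M, K1) view the gridwise GEMM consumes. The index arithmetic of that unmerge, written out as a plain function (illustrative only; it assumes an unpadded row-major (M, K) layout, whereas real descriptors may be strided and padded):

// Coordinate mapping realized by MakeAGridDescriptor_AK0_M_AK1:
// element (ak0, m, ak1) of the (AK0, M, AK1) view is element (m, k)
// of the original (M, K) descriptor, with k = ak0 * AK1 + ak1.
int flat_offset(int ak0, int m, int ak1, int AK1, int K)
{
    const int k = ak0 * AK1 + ak1; // inverse of make_unmerge_transform
    return m * K + k;              // assumed row-major (M, K) storage
}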
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp

@@ -48,9 +48,9 @@ namespace device {
template <typename DeviceOp, typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp

@@ -34,23 +34,23 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_contraction_multiple_d_xdl_cshuffle(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatDsPointer p_ds_grid,
            FloatE* __restrict__ p_e_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp

@@ -404,7 +404,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
            BBlockTransferSrcVectorDim,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_K1,
            false, // BThreadTransferSrcResetCoordinateAfterRun,
            BBlockLdsAddExtraN,
            Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
            7,                                // CThreadTransferSrcDstVectorDim,
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp

@@ -436,7 +436,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
            BlockSize,
            ABDataType, // TODO: distinguish A/B datatype
            AccDataType,
            CDataType, // TODO: Add ShuffleType for DeviceConv2d
            CDataType,
            InMemoryDataOperationEnum::Set,
            AGridDesc_K0_M_K1,
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp

@@ -354,7 +354,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
            2, // BBlockTransferSrcVectorDim,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_K1,
            false, // BThreadTransferSrcResetCoordinateAfterRun,
            BBlockLdsAddExtraN,
            Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
            7,                                // CThreadTransferSrcDstVectorDim,
include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp

@@ -37,23 +37,23 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_xdlops_v2r3_for_conv3d(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const index_t num_batches,
            const index_t a_batch_stride,
            const index_t b_batch_stride,
            const index_t c_batch_stride,
            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
            const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
            const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp

@@ -1005,7 +1005,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
            BBlockTransferSrcVectorDim,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_K1,
            false, // BThreadTransferSrcResetCoordinateAfterRun,
            BBlockLdsAddExtraN,
            Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
            7,                                // CThreadTransferSrcDstVectorDim,