gaoqiong/composable_kernel, commit ed305f6b
Authored Sep 28, 2023 by Umang Yadav. Parent: 9f4e3544.

Commit message: formatting
Changes: 45 files. Showing 20 changed files with 311 additions and 304 deletions (the remaining files are on the following pages of the diff).
include/ck/ck.hpp (+4, -4)
include/ck/tensor_description/tensor_adaptor.hpp (+14, -10)
include/ck/tensor_description/tensor_space_filling_curve.hpp (+4, -2)
include/ck/tensor_operation/gpu/device/device_base.hpp (+7, -8)
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp (+18, -18)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp (+14, -14)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp (+19, -18)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp (+17, -17)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp (+16, -16)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp (+25, -25)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp (+19, -19)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp (+23, -23)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp (+93, -92)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp (+2, -2)
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp (+16, -16)
include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp (+1, -1)
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp (+1, -1)
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp (+1, -1)
include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp (+16, -16)
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp (+1, -1)
include/ck/ck.hpp

include/ck/tensor_description/tensor_adaptor.hpp
@@ -108,12 +108,12 @@ struct TensorAdaptor

    __host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
    {
        constexpr auto all_low_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            LowerDimensionHiddenIdss{});

        constexpr auto all_up_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            UpperDimensionHiddenIdss{});

        constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
@@ -338,7 +338,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a

            TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran];

        // sequence in, sequence out
        constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
            auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);

            // shift hidden id so every dim id is unique
@@ -360,7 +361,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a

            });

            return low_dim_hidden_ids_1_mod_;
        }
        ();

        return generate_sequence_v2(
            [&](auto i) constexpr { return Number<low_dim_hidden_ids_1_mod[i]>{}; },
@@ -382,7 +384,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a

            TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran];

        // sequence in, constexpr tuple out
        constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
            auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);

            // shift hidden id

@@ -391,7 +394,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a

            });

            return up_dim_hidden_ids_1_mod_;
        }
        ();

        // constexpr tuple to sequence
        return generate_sequence_v2(
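Note: the "[&]() constexpr { ... } ();" shape in these hunks is an immediately invoked constexpr lambda; the result is computed once at compile time and bound to a constexpr variable (this commit only moved the trailing "();" onto its own line). A standalone sketch of the pattern:

// Immediately invoked constexpr lambda: build a value with imperative steps,
// then freeze it into a constexpr constant.
constexpr int sum_of_squares = []() constexpr {
    int acc = 0;
    for(int i = 1; i <= 4; ++i)
    {
        acc += i * i;
    }
    return acc;
}();

static_assert(sum_of_squares == 30);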
include/ck/tensor_description/tensor_space_filling_curve.hpp
@@ -94,8 +94,10 @@ struct SpaceFillingCurve

        // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
        // idim-th element of multidimensional index.
        // All constexpr variables have to be captured by VALUE.
        constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr
        {
            constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr
            {
                auto res = idx_1d.value;
                auto id  = 0;
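Note: the "captured by VALUE" comment matters because compute_index is itself declared constexpr: capturing idx_1d and access_strides by value gives the closure its own literal copies instead of referring back to the enclosing variables. A minimal standalone illustration (the Number wrapper here is hypothetical, not CK's):

template <int N>
struct Number
{
    static constexpr int value = N;
};

constexpr int demo()
{
    constexpr int idx_1d = 5;
    constexpr int stride = 3;

    // Capture by value: the lambda owns copies of idx_1d and stride, so it
    // stays a valid constant expression wherever it is invoked.
    constexpr auto compute = [idx_1d, stride](auto idim) constexpr {
        return idim.value == 0 ? idx_1d / stride : idx_1d % stride;
    };

    return compute(Number<0>{}) * 10 + compute(Number<1>{});
}

static_assert(demo() == 12); // 5 / 3 == 1, 5 % 3 == 2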
include/ck/tensor_operation/gpu/device/device_base.hpp
@@ -47,7 +47,6 @@ struct BaseOperator

    virtual bool IsSupportedArgument(const BaseArgument*) { return false; }

#ifndef __HIPCC_RTC__
    virtual std::string GetTypeString() const { return ""; }

@@ -66,7 +65,7 @@ struct BaseOperator

    virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
    {
        // assert(p_arg);
        p_arg->p_workspace_ = p_workspace;
    }
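Note: IsSupportedArgument defaulting to false and GetTypeString defaulting to "" let callers probe a heterogeneous list of operator instances safely through the base interface. A hedged sketch of that typical consumption pattern (the mirror declarations and the selection loop are illustrative, not CK code):

#include <memory>
#include <string>
#include <vector>

struct BaseArgument
{
    virtual ~BaseArgument() = default;
    void* p_workspace_      = nullptr;
};

struct BaseOperator
{
    virtual ~BaseOperator() = default;
    virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
    virtual std::string GetTypeString() const { return ""; }
};

// Find the first operator instance that accepts the given argument.
std::string first_supported(const std::vector<std::unique_ptr<BaseOperator>>& ops,
                            const BaseArgument* arg)
{
    for(const auto& op : ops)
    {
        if(op->IsSupportedArgument(arg))
        {
            return op->GetTypeString(); // human-readable instance name
        }
    }
    return ""; // no instance supports this problem size/layout
}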
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
@@ -38,7 +38,7 @@ template <typename GridwiseGemm,

          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_contraction_multiple_d_xdl_cshuffle(
    const FloatAB* __restrict__ p_a_grid,
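Note: every kernel touched in these files wraps its __launch_bounds__ attribute in the same CK_USE_LAUNCH_BOUNDS guard, so the occupancy hint can be compiled out. A minimal HIP/CUDA sketch of the pattern; the bound values 256 and 1 are placeholder assumptions, not CK's configuration:

#define CK_USE_LAUNCH_BOUNDS 1
#define CK_MAX_THREAD_PER_BLOCK 256 // placeholder value
#define CK_MIN_BLOCK_PER_CU 1       // placeholder value

__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
copy_kernel(float* __restrict__ dst, const float* __restrict__ src, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
    {
        dst[i] = src[i];
    }
}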
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
@@ -60,7 +60,7 @@ template <typename GridwiseGemm,

          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
                                  const ABDataType* __restrict__ p_b_grid,
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
@@ -41,9 +41,10 @@ template <typename GridwiseGemm,

          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
                                 const FloatAB* __restrict__ p_b_grid,
                                 const FloatAB* __restrict__ p_b1_grid,
                                 FloatC* __restrict__ p_c_grid,
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
@@ -63,7 +63,7 @@ template <typename GridwiseGemm,

          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid,
                        const ABDataType* __restrict__ p_b_grid,
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
@@ -52,7 +52,7 @@ template <typename GridwiseGemm,

          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dl_multiple_d(
    const ABDataType* __restrict__ p_a_grid,
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
@@ -41,7 +41,7 @@ template <typename GridwiseGemm,

          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_gemm_xdl_cshuffle_v1(
    const A0B0B1DataType* __restrict__ p_a0_grid,
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
@@ -38,7 +38,7 @@ template <typename GridwiseGemm,

          bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_reduce_xdl_cshuffle_v1(
    const FloatAB* __restrict__ p_a_grid,
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
@@ -42,7 +42,7 @@ template <typename GridwiseGemm,

          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
    const FloatAB* __restrict__ p_a_grid,
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
@@ -40,7 +40,7 @@ template <typename GridwiseGemm,

          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
    const FloatAB* __restrict__ p_a_grid,
@@ -611,7 +611,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle

            return true;
        }

        static constexpr bool
        IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
        {
            // check vector load/store
            using Row = ck::tensor_layout::gemm::RowMajor;
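Note: the IsSupported overload above validates raw (unpadded) problem sizes against the instance's vector load/store configuration before a launch is attempted. A sketch of the shape of such a check; the divisibility constraint and the width 8 are illustrative assumptions, not the file's exact rules:

// A contiguous dimension can only be accessed with vector width V if its raw
// length divides evenly; otherwise the instance must be rejected.
constexpr bool is_vector_access_ok(int raw_length, int vector_width)
{
    return raw_length % vector_width == 0;
}

static_assert(is_vector_access_ok(128, 8));
static_assert(!is_vector_access_ok(130, 8));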
@@ -842,7 +843,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle

    template <class ADesc, class BDesc, class B1Desc, class CDesc>
    struct Descriptor
    {
        template <class AGridDescriptor>
        static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
        {
            const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
@@ -852,14 +853,15 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle

            const auto AK0 = K / AK1;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class BGridDescriptor>
        static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
        {
            const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
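Note: MakeAGridDescriptor_AK0_M_AK1 splits the K axis into (AK0, AK1) with an unmerge transform, so the inner AK1 elements stay contiguous for vector loads. In plain arithmetic the transform is a division/modulus pair, sketched standalone below (K = 64 and AK1 = 8 are example values, not taken from an instance):

constexpr int K   = 64;
constexpr int AK1 = 8;
constexpr int AK0 = K / AK1;

// Unmerge: linear coordinate k maps to (k0, k1) with k = k0 * AK1 + k1.
constexpr int k  = 27;
constexpr int k0 = k / AK1;
constexpr int k1 = k % AK1;

static_assert(AK0 * AK1 == K);
static_assert(k0 * AK1 + k1 == k); // the mapping round-trips exactly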
@@ -869,14 +871,15 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle

            const auto BK0 = K / BK1;

            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class B1GridDescriptor>
        static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
        {
            const auto b1_grid_desc_n_k =
                DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
@@ -894,21 +897,19 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle

                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class CGridDescriptor>
        static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
        {
            return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
        }

        using AGridDesc_AK0_M_AK1 =
            remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
        using BGridDesc_BK0_N_BK1 =
            remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
        using B1GridDesc_BK0_N_BK1 =
            remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
        using CGridDesc_M_N =
            remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;

        // GridwiseGemm
        using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
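Note: the using-declarations above derive the concrete descriptor types by calling the maker functions on value-initialized descriptors inside decltype, then stripping references and cv-qualifiers. A standalone reduction of that idiom (the alias mirrors CK's remove_cvref_t; the maker function is a stand-in):

#include <type_traits>

template <typename T>
using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;

constexpr auto make_descriptor(int extent) { return extent * 2; } // stand-in maker

// "Whatever type the maker returns for a default-constructed argument."
using Descriptor = remove_cvref_t<decltype(make_descriptor(int{}))>;

static_assert(std::is_same_v<Descriptor, int>);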
@@ -979,7 +980,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle

        CGridDesc_M_N c_grid_desc_m_n;
        C0MatrixMask c0_matrix_mask;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_descriptor_mblock_mperblock_nblock_nperblock;

        // element-wise op
        AElementwiseOperation a_element_op;
@@ -1002,10 +1004,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle

              b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
              b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
              c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
              block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
              c_grid_descriptor_mblock_mperblock_nblock_nperblock{
                  GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                      c_grid_desc_m_n)},
              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
              c0_matrix_mask{c.GetLength(I1)},
@@ -1013,23 +1015,20 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle

              b_element_op{b_element_op_},
              b1_element_op{b1_element_op_},
              c_element_op{c_element_op_},
              is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   b1_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_m_n,
                                                   block_2_ctile_map) and
                       IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
                                   b_grid_desc_bk0_n_bk1.GetLength(I1),
                                   a_grid_desc_ak0_m_ak1.GetLength(I0) *
                                       a_grid_desc_ak0_m_ak1.GetLength(I2),
                                   b1_grid_desc_bk0_n_bk1.GetLength(I1))}
        {
        }

        constexpr bool IsValid() const { return is_valid; }
    };

    template <class ADesc, class BDesc, class B1Desc, class CDesc>
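Note: the argument struct folds GridwiseGemm::CheckValidity and IsSupported into a single is_valid member computed in the constructor's initializer list, so IsValid() is a cheap flag read afterwards. A standalone sketch of that construct-then-query pattern (the constraints shown are hypothetical):

struct Argument
{
    constexpr Argument(int M_, int N_, int K_)
        : M{M_},
          N{N_},
          K{K_},
          // All checks run once, at construction; hypothetical constraints.
          is_valid{M > 0 && N > 0 && K % 8 == 0}
    {
    }

    constexpr bool IsValid() const { return is_valid; }

    int M;
    int N;
    int K;
    bool is_valid;
};

static_assert(Argument{128, 128, 64}.IsValid());
static_assert(!Argument{128, 128, 63}.IsValid());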
@@ -1061,7 +1060,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle

        if(desc.has_main_k_block_loop)
        {
            Desc::GridwiseGemm::template Run<true>(p_a_grid,
                                                   p_b_grid,
                                                   p_b1_grid,
                                                   p_c_grid,
@@ -1080,7 +1080,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle

        }
        else
        {
            Desc::GridwiseGemm::template Run<false>(p_a_grid,
                                                    p_b_grid,
                                                    p_b1_grid,
                                                    p_c_grid,
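Note: has_main_k_block_loop is a runtime fact about the problem size, but the kernel wants it as a compile-time constant, so the launcher branches once and instantiates Run<true> or Run<false>. A standalone sketch of that dispatch (names are illustrative):

#include <cstdio>

// The loop structure is fixed at compile time per instantiation.
template <bool HasMainLoop>
int run(int iters)
{
    int acc = 0;
    if constexpr(HasMainLoop)
    {
        for(int i = 0; i < iters; ++i)
        {
            acc += i;
        }
    }
    return acc;
}

// One runtime branch selects the instantiation, mirroring the
// has_main_k_block_loop dispatch above.
int dispatch(bool has_main_loop, int iters)
{
    return has_main_loop ? run<true>(iters) : run<false>(iters);
}

int main()
{
    std::printf("%d %d\n", dispatch(true, 5), dispatch(false, 5)); // 10 0
    return 0;
}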
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
@@ -48,7 +48,7 @@ namespace device {

template <typename DeviceOp, typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
{
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
@@ -34,7 +34,7 @@ template <typename GridwiseGemm,

          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_contraction_multiple_d_xdl_cshuffle(
    const FloatAB* __restrict__ p_a_grid,
include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp

include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp

include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp

include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
@@ -37,7 +37,7 @@ template <typename GridwiseGemm,

          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r3_for_conv3d(
    const FloatAB* __restrict__ p_a_grid,
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp