gaoqiong / composable_kernel_ROCM

Commit 9c0811f3
authored Sep 24, 2024 by Jun Liu

    Merge branch 'develop' into amd-develop

Parents: ded0d83d, 3528a523

Changes: 153 files in the full commit; this page shows 20 changed files with 1590 additions and 245 deletions (+1590 -245).
Changed files:

    include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp   +326 -50
    include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp                              +17  -0
    include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp                            +349 -74
    include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                               +28  -1
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp                                 +4   -4
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp                         +4   -4
    include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp                       +1   -1
    include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp                                            +38  -14
    include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp                       +236 -0
    include/ck/utility/amd_smfmac.hpp                                                                      +14  -12
    include/ck/utility/reduction_operator.hpp                                                              +164 -9
    include/ck_tile/host/reference/reference_gemm.hpp                                                      +131 -4
    include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp                           +45  -40
    include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp                                     +2   -2
    include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp                               +1   -1
    include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp                       +27  -24
    include/ck_tile/ops/gemm.hpp                                                                           +2   -0
    include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp                                       +6   -5
    include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp                       +4   -0
    include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp                                                        +191 -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp

    (Diff collapsed in the page view; contents not shown.)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp

@@ -26,6 +26,15 @@ constexpr bool is_GNWC_GKXC_GNWK()
            is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
            is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
 }
 
+template <typename InLayout, typename WeiLayout, typename OutLayout>
+constexpr bool is_NGCW_GKXC_NGKW()
+{
+    return is_same_v<InLayout, tensor_layout::convolution::NGCW> &&
+           is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
+           is_same_v<OutLayout, tensor_layout::convolution::NGKW>;
+}
+
 // 2d
 template <typename InLayout, typename WeiLayout, typename OutLayout>
 constexpr bool is_NHWGC_GKYXC_NHWGK()

@@ -91,6 +100,14 @@ constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
            is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
 }
 
+template <typename InLayout, typename WeiLayout, typename OutLayout>
+constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial()
+{
+    return is_NGCW_GKXC_NGKW<InLayout, WeiLayout, OutLayout>() ||
+           is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+           is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>();
+}
+
 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
 struct ComputePtrOffsetOfStridedBatch
 {
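For orientation (not part of the commit): the new helpers are constexpr predicates over the layout tag types, so a device implementation can gate an NGCW-specific path exactly the way the existing GNWC checks are used. A minimal sketch, assuming the helpers live in ck::tensor_operation::device as the qualified calls elsewhere in this commit suggest:

    // Hedged illustration: compile-time layout dispatch with the new predicate.
    namespace conv   = ck::tensor_layout::convolution;
    namespace device = ck::tensor_operation::device;

    // The predicate matches only the exact (NGCW, GKXC, NGKW) triple...
    static_assert(device::is_NGCW_GKXC_NGKW<conv::NGCW, conv::GKXC, conv::NGKW>());
    // ...and rejects the grouped-batch GNWC layouts handled by the older checks.
    static_assert(!device::is_NGCW_GKXC_NGKW<conv::GNWC, conv::GKXC, conv::GNWK>());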
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp

    (Diff collapsed in the page view; contents not shown.)
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -355,12 +355,39 @@ struct UnaryDivide
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, int32_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
                       "Data type is not supported by this operation!");
 
         y = x / type_convert<T>(divider_);
     };
 
+    template <>
+    __host__ __device__ void operator()<half_t>(half_t& y, const half_t& x) const
+    {
+        float x_         = type_convert<float>(x);
+        float divider_f_ = type_convert<float>(divider_);
+        y                = type_convert<half_t>(x_ / divider_f_);
+    };
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x) const
+    {
+        float x_         = type_convert<float>(x);
+        float divider_f_ = type_convert<float>(divider_);
+        y                = type_convert<bhalf_t>(x_ / divider_f_);
+    };
+
+    template <>
+    __host__ __device__ void operator()<f8_t>(f8_t& y, const f8_t& x) const
+    {
+        float x_         = type_convert<float>(x);
+        float divider_f_ = type_convert<float>(divider_);
+        y                = type_convert<f8_t>(x_ / divider_f_);
+    };
+
     int32_t divider_ = 1;
 };
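The pattern in these new specializations — widen to float, divide, convert back — performs the arithmetic in a type that can represent intermediate values exactly, and rounds only once at the end. A self-contained stand-in for the same idea, emulating bfloat16 by truncating a float to its top 16 bits (illustrative; CK uses type_convert<> with proper rounding):

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    // bf16 stand-in: keep the top 16 bits of a float (round-to-zero for brevity).
    static std::uint16_t to_bf16(float f)
    {
        std::uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        return static_cast<std::uint16_t>(u >> 16);
    }
    static float from_bf16(std::uint16_t h)
    {
        std::uint32_t u = static_cast<std::uint32_t>(h) << 16;
        float f;
        std::memcpy(&f, &u, sizeof(f));
        return f;
    }

    int main()
    {
        std::uint16_t x = to_bf16(10.0f);
        int divider     = 3;
        // convert-compute-convert, as the new operator()<bhalf_t> does:
        float y = from_bf16(x) / static_cast<float>(divider);
        std::printf("%f\n", static_cast<double>(from_bf16(to_bf16(y))));
        return 0;
    }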
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

@@ -221,7 +221,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                               make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
     }
 
-    __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
+    __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
         index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
     {
         const auto a_grid_desc_mraw_kraw = [&]() {

@@ -303,7 +303,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
         }
     }
 
-    __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
+    __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
         index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
     {
         const auto b_grid_desc_nraw_kraw = [&]() {

@@ -576,12 +576,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
         {
-            a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
+            a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
         }
 
         if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
         {
-            b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
+            b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
         {
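The offset fix in the last hunk matters whenever the matrix is stored with padding, i.e. its leading dimension differs from the logical extent. Each z-block of the split-K kernel skips KRead columns of column-major A, and one column occupies StrideA (not M) elements. A small host-side check of the two formulas, with hypothetical numbers:

    #include <cassert>

    int main()
    {
        // Column-major A with padding: leading dimension > M.
        const int M = 100, StrideA = 128, KRead = 32;
        const int z = 2; // blockIdx.z of the split-K kernel

        // Element offset of column (z * KRead) in column-major storage:
        const int fixed = z * KRead * StrideA; // what the commit writes
        const int old   = z * KRead * M;       // what it replaced

        assert(fixed == 8192);
        assert(old == 6400); // lands mid-column whenever StrideA != M
        return 0;
    }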
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp

@@ -255,7 +255,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                               make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
     }
 
-    __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
+    __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
         index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
     {
         const auto a_grid_desc_mraw_kraw = [&]() {

@@ -337,7 +337,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         }
     }
 
-    __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
+    __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
         index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
     {
         const auto b_grid_desc_nraw_kraw = [&]() {

@@ -647,12 +647,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
         {
-            a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
+            a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
         }
 
         if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
         {
-            b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
+            b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
         {
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp

@@ -315,7 +315,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1
             forward_sweep_(I0) = true;
 
             static_for<1, nDim, 1>{}([&](auto i) {
-                index_t tmp = ordered_dst_access_idx[I0];
+                index_t tmp = 0;
 
                 static_for<0, i, 1>{}([&](auto j) {
                     tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
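This one-line fix is an accumulator-seeding bug: the Horner-style fold over j = 0..i-1 already folds in ordered_dst_access_idx[I0] at j = 0, so seeding tmp with that value counted dimension 0 twice. The same fold in isolation:

    #include <array>
    #include <cassert>

    int main()
    {
        const std::array<int, 3> lengths = {4, 5, 6};
        const std::array<int, 3> idx     = {1, 2, 3};

        int tmp = 0; // the fix: start empty, the j-loop adds idx[0] itself
        for(int j = 0; j < 3; ++j)
            tmp = tmp * lengths[j] + idx[j];

        assert(tmp == (1 * 5 + 2) * 6 + 3); // row-major linear offset = 45
        return 0;
    }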
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp

@@ -35,10 +35,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32f16>
     static constexpr index_t k_per_blk   = 8;
     static constexpr bool is_k_reduction = true;
 
-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              index_t idx_part,
+              class FloatA,
+              class FloatB,
+              class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
     {
-        intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
+        intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
+            a, b, idx, reg_c);
     }
 };

@@ -57,10 +63,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16f16>
     static constexpr index_t k_per_blk   = 16;
     static constexpr bool is_k_reduction = true;
 
-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              index_t idx_part,
+              class FloatA,
+              class FloatB,
+              class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
     {
-        intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
+        intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
+            a, b, idx, reg_c);
     }
 };

@@ -79,10 +91,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32bf16>
     static constexpr index_t k_per_blk   = 8;
     static constexpr bool is_k_reduction = true;
 
-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              index_t idx_part,
+              class FloatA,
+              class FloatB,
+              class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
     {
-        intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
+        intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
+            a, b, idx, reg_c);
     }
 };

@@ -101,10 +119,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16bf16>
     static constexpr index_t k_per_blk   = 16;
     static constexpr bool is_k_reduction = true;
 
-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              index_t idx_part,
+              class FloatA,
+              class FloatB,
+              class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
     {
-        intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
+        intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
+            a, b, idx, reg_c);
     }
 };

@@ -305,8 +329,8 @@ struct SparseXdlopsGemm
                       "base base_type must be half or bfloat16!");
 
         static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) {
-            smfmac_instr.template run<MPerXdlops, NPerXdlops>(
-                p_a_wave[k], p_b_wave[k], idx[k], p_c_thread);
+            smfmac_instr.template run<MPerXdlops, NPerXdlops, k % 4>(
+                p_a_wave[k], p_b_wave[k], idx[k / 4], p_c_thread);
        });
    }
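To see why the call site changed: each 32-bit index register packs four 8-bit sets of sparse indices (see the new comment in amd_smfmac.hpp below), so step k of the k_per_blk loop reads register idx[k / 4] and selects byte k % 4 inside it, passed down as the new idx_part template argument. A host-side sketch of just the packing arithmetic, with made-up register contents:

    #include <cassert>
    #include <cstdint>

    int main()
    {
        // Four 8-bit index sets packed into one 32-bit register.
        const std::uint32_t reg = 0x44332211u;

        for(int k = 0; k < 8; ++k)
        {
            const int reg_slot = k / 4; // which idx[] register the step reads
            const int abid     = k % 4; // which byte the instruction selects
            const std::uint8_t set =
                static_cast<std::uint8_t>(reg >> (8 * abid)); // low byte after shift
            if(reg_slot == 0 && abid == 2)
                assert(set == 0x33);
        }
        return 0;
    }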
include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"

namespace ck {
namespace tensor_operation {

template <typename ALayout,
          typename BLayout,
          typename ELayout,
          index_t NDimSpatial,
          index_t MPerThread,
          index_t NPerThread>
struct TransformConvNGCHWToNHWGC
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
    static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
    {
        const index_t& G  = g_n_c_wis_lengths[I0];
        const index_t& N  = g_n_c_wis_lengths[I1];
        const index_t& C  = g_n_c_wis_lengths[I2];
        const index_t& Wi = g_n_c_wis_lengths[I3];

        const index_t& GStride  = g_n_c_wis_strides[I0];
        const index_t& NStride  = g_n_c_wis_strides[I1];
        const index_t& CStride  = g_n_c_wis_strides[I2];
        const index_t& WiStride = g_n_c_wis_strides[I3];

        const auto desc = make_naive_tensor_descriptor(
            make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride));
        const auto merged_desc =
            transform_tensor_descriptor(desc,
                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
                                                   make_merge_transform(make_tuple(Wi))),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
        return device::PadTensorDescriptor(
            merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
    }

    template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
    static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
    {
        const index_t& G  = g_n_c_wis_lengths[I0];
        const index_t& N  = g_n_c_wis_lengths[I1];
        const index_t& C  = g_n_c_wis_lengths[I2];
        const index_t& Wi = g_n_c_wis_lengths[I3];

        const index_t& NStride = g_n_c_wis_strides[I1];
        const index_t WiStride = G * C;
        const index_t GStride  = C;
        const index_t CStride  = 1;

        const auto desc = make_naive_tensor_descriptor(
            make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride));
        const auto merged_desc =
            transform_tensor_descriptor(desc,
                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
                                                   make_merge_transform(make_tuple(Wi))),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
        return device::PadTensorDescriptor(
            merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
    }

    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
    static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
    {
        const index_t& G  = g_n_c_wis_lengths[I0];
        const index_t& N  = g_n_c_wis_lengths[I1];
        const index_t& C  = g_n_c_wis_lengths[I2];
        const index_t& Hi = g_n_c_wis_lengths[I3];
        const index_t& Wi = g_n_c_wis_lengths[I4];

        const index_t& GStride  = g_n_c_wis_strides[I0];
        const index_t& NStride  = g_n_c_wis_strides[I1];
        const index_t& CStride  = g_n_c_wis_strides[I2];
        const index_t& HiStride = g_n_c_wis_strides[I3];
        const index_t& WiStride = g_n_c_wis_strides[I4];

        const auto desc =
            make_naive_tensor_descriptor(make_tuple(N, G, C, Hi, Wi),
                                         make_tuple(NStride, GStride, CStride, HiStride, WiStride));
        const auto merged_desc =
            transform_tensor_descriptor(desc,
                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
                                                   make_merge_transform(make_tuple(Hi, Wi))),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
        return device::PadTensorDescriptor(
            merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
    }

    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
    static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
    {
        const index_t& G  = g_n_c_wis_lengths[I0];
        const index_t& N  = g_n_c_wis_lengths[I1];
        const index_t& C  = g_n_c_wis_lengths[I2];
        const index_t& Hi = g_n_c_wis_lengths[I3];
        const index_t& Wi = g_n_c_wis_lengths[I4];

        const index_t& NStride = g_n_c_wis_strides[I1];
        const index_t HiStride = Wi * G * C;
        const index_t WiStride = G * C;
        const index_t GStride  = C;
        const index_t CStride  = 1;

        const auto desc =
            make_naive_tensor_descriptor(make_tuple(N, G, C, Hi, Wi),
                                         make_tuple(NStride, GStride, CStride, HiStride, WiStride));
        const auto merged_desc =
            transform_tensor_descriptor(desc,
                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
                                                   make_merge_transform(make_tuple(Hi, Wi))),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
        return device::PadTensorDescriptor(
            merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
    }

    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
    static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
    {
        const index_t& G  = g_n_c_wis_lengths[I0];
        const index_t& N  = g_n_c_wis_lengths[I1];
        const index_t& C  = g_n_c_wis_lengths[I2];
        const index_t& Di = g_n_c_wis_lengths[I3];
        const index_t& Hi = g_n_c_wis_lengths[I4];
        const index_t& Wi = g_n_c_wis_lengths[I5];

        const index_t& GStride  = g_n_c_wis_strides[I0];
        const index_t& NStride  = g_n_c_wis_strides[I1];
        const index_t& CStride  = g_n_c_wis_strides[I2];
        const index_t& DiStride = g_n_c_wis_strides[I3];
        const index_t& HiStride = g_n_c_wis_strides[I4];
        const index_t& WiStride = g_n_c_wis_strides[I5];

        const auto desc = make_naive_tensor_descriptor(
            make_tuple(N, G, C, Di, Hi, Wi),
            make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
        const auto merged_desc =
            transform_tensor_descriptor(desc,
                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
                                                   make_merge_transform(make_tuple(Di, Hi, Wi))),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
        return device::PadTensorDescriptor(
            merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
    }

    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
    static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
    {
        const index_t& G  = g_n_c_wis_lengths[I0];
        const index_t& N  = g_n_c_wis_lengths[I1];
        const index_t& C  = g_n_c_wis_lengths[I2];
        const index_t& Di = g_n_c_wis_lengths[I3];
        const index_t& Hi = g_n_c_wis_lengths[I4];
        const index_t& Wi = g_n_c_wis_lengths[I5];

        const index_t& NStride = g_n_c_wis_strides[I1];
        const index_t DiStride = Hi * Wi * G * C;
        const index_t HiStride = Wi * G * C;
        const index_t WiStride = G * C;
        const index_t GStride  = C;
        const index_t CStride  = 1;

        const auto desc = make_naive_tensor_descriptor(
            make_tuple(N, G, C, Di, Hi, Wi),
            make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
        const auto merged_desc =
            transform_tensor_descriptor(desc,
                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
                                                   make_merge_transform(make_tuple(Di, Hi, Wi))),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
        return device::PadTensorDescriptor(
            merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
    }

    static auto TransposeStrides(const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
                                 const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_strides)
    {
        if constexpr(device::is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
                     device::is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
        {
            std::array<index_t, NDimSpatial + 3> g_n_c_wis_strides_transposed;
            const auto G = g_n_c_wis_lengths[I0];
            const auto C = g_n_c_wis_lengths[I2];

            g_n_c_wis_strides_transposed[I0] = C;
            g_n_c_wis_strides_transposed[I1] = g_n_c_wis_strides[I1];
            g_n_c_wis_strides_transposed[I2] = I1;
            if constexpr(NDimSpatial == 2)
            {
                g_n_c_wis_strides_transposed[I3] = g_n_c_wis_lengths[I4] * G * C;
                g_n_c_wis_strides_transposed[I4] = G * C;
            }
            else if constexpr(NDimSpatial == 3)
            {
                g_n_c_wis_strides_transposed[I3] =
                    g_n_c_wis_lengths[I4] * g_n_c_wis_lengths[I5] * G * C;
                g_n_c_wis_strides_transposed[I4] = g_n_c_wis_lengths[I5] * G * C;
                g_n_c_wis_strides_transposed[I5] = G * C;
            }
            return g_n_c_wis_strides_transposed;
        }
        else
        {
            // transpose not needed
            return g_n_c_wis_strides;
        }
    }
};

} // namespace tensor_operation
} // namespace ck
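A hedged sketch of how the two descriptor makers pair up: MakeNGCHWTransposeDesc views the tensor through its real NGCHW strides, while MakeNHWGCTransposeDesc views the same (N*G*C) x (spatial) index space through the strides a contiguous NHWGC layout would have, so an elementwise copy between the two views performs the transpose. Illustrative only; the surrounding device op supplies the real MPerThread/NPerThread values:

    // Hypothetical instantiation for a 2-D grouped conv (NDimSpatial = 2).
    using Transform = ck::tensor_operation::TransformConvNGCHWToNHWGC<
        ck::tensor_layout::convolution::NGCHW, // ALayout
        ck::tensor_layout::convolution::GKYXC, // BLayout
        ck::tensor_layout::convolution::NGKHW, // ELayout
        2,                                     // NDimSpatial
        4,                                     // MPerThread (assumed)
        4>;                                    // NPerThread (assumed)

    std::array<ck::index_t, 5> lengths = {2, 8, 16, 28, 28}; // G, N, C, Hi, Wi
    std::array<ck::index_t, 5> strides = /* NGCHW strides of the input */ {};

    // Source view: data as it sits in memory (NGCHW strides)...
    const auto src = Transform::MakeNGCHWTransposeDesc<2>(lengths, strides);
    // ...destination view: same index space, contiguous NHWGC strides.
    const auto dst = Transform::MakeNHWGCTransposeDesc<2>(lengths, strides);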
include/ck/utility/amd_smfmac.hpp

@@ -9,16 +9,18 @@ namespace ck {
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_smfmac_f32_16x16x32f16;
 
+// for every smfmac instruction if CBSZ[1:0]=0, ABID[1:0] selects one of four 8-bit sets of sparse
+// indices from reg_idx
 template <>
 struct intrin_smfmac_f32_16x16x32f16<16, 16>
 {
-    template <class FloatC>
+    template <class FloatC, index_t abid = 0>
     __device__ static void
-    Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
+    Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
     {
 #if defined(__gfx94__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
-            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
 #else
         ignore = reg_a;
         ignore = reg_b;

@@ -34,13 +36,13 @@ struct intrin_smfmac_f32_16x16x32bf16;
 template <>
 struct intrin_smfmac_f32_16x16x32bf16<16, 16>
 {
-    template <class FloatC>
+    template <class FloatC, index_t abid = 0>
     __device__ static void
-    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
+    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
     {
 #if defined(__gfx94__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
-            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
 #else
         ignore = reg_a;
         ignore = reg_b;

@@ -56,13 +58,13 @@ struct intrin_smfmac_f32_32x32x16f16;
 template <>
 struct intrin_smfmac_f32_32x32x16f16<32, 32>
 {
-    template <class FloatC>
+    template <class FloatC, index_t abid = 0>
     __device__ static void
-    Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
+    Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
     {
 #if defined(__gfx94__)
         reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
-            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
 #else
         ignore = reg_a;
         ignore = reg_b;

@@ -78,13 +80,13 @@ struct intrin_smfmac_f32_32x32x16bf16;
 template <>
 struct intrin_smfmac_f32_32x32x16bf16<32, 32>
 {
-    template <class FloatC>
+    template <class FloatC, index_t abid = 0>
     __device__ static void
-    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
+    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
     {
 #if defined(__gfx94__)
         reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
-            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
 #else
         ignore = reg_a;
         ignore = reg_b;
include/ck/utility/reduction_operator.hpp

@@ -52,12 +52,28 @@ struct Add
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, int32_t>::value || is_same<T, half_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
                       "The data type is not supported by the Add accumulator!");
 
         a = a + b;
     }
 
+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+        a        = type_convert<f8_t>(a_ + b_);
+    }
+
+    __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+        a        = type_convert<half_t>(a_ + b_);
+    }
+
     __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
     {
         float a_ = type_convert<float>(a);

@@ -112,12 +128,28 @@ struct Mul
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, int32_t>::value || is_same<T, half_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
                       "The data type is not supported by the Mul accumulator!");
 
         a = a * b;
     }
 
+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+        a        = type_convert<f8_t>(a_ * b_);
+    }
+
+    __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+        a        = type_convert<half_t>(a_ * b_);
+    }
+
     __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
     {
         float a_ = type_convert<float>(a);

@@ -137,6 +169,16 @@ struct Max
             float val = NumericLimits<float>::Lowest();
             return type_convert<bhalf_t>(val);
         }
+        if constexpr(is_same_v<T, f8_t>)
+        {
+            float val = NumericLimits<float>::Lowest();
+            return type_convert<f8_t>(val);
+        }
+        if constexpr(is_same_v<T, half_t>)
+        {
+            float val = NumericLimits<float>::Lowest();
+            return type_convert<half_t>(val);
+        }
         else
         {
             return NumericLimits<T>::Lowest();

@@ -154,8 +196,7 @@ struct Max
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
                       "The data type is not supported by the Max accumulator!");
 
         if(a < b)

@@ -171,12 +212,29 @@ struct Max
             a = b;
     }
 
+    __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ < b_)
+            a = b;
+    }
+
+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ < b_)
+            a = b;
+    }
+
     template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
                       "The data type is not supported by the Max accumulator!");
 
         if(a < b)

@@ -197,6 +255,30 @@ struct Max
             changed = true;
         }
     }
 
+    __host__ __device__ inline constexpr void operator()(half_t& a, half_t b, bool& changed) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ < b_)
+        {
+            a       = b;
+            changed = true;
+        }
+    }
+
+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ < b_)
+        {
+            a       = b;
+            changed = true;
+        }
+    }
 };
 
 struct Min

@@ -209,6 +291,16 @@ struct Min
             float val = NumericLimits<float>::Max();
             return type_convert<bhalf_t>(val);
         }
+        else if constexpr(is_same_v<T, half_t>)
+        {
+            float val = NumericLimits<float>::Max();
+            return type_convert<half_t>(val);
+        }
+        else if constexpr(is_same_v<T, f8_t>)
+        {
+            float val = NumericLimits<float>::Max();
+            return type_convert<f8_t>(val);
+        }
         else
         {
             return NumericLimits<T>::Max();

@@ -227,8 +319,7 @@ struct Min
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
                       "The data type is not supported by the Min accumulator!");
 
         if(a > b)

@@ -244,6 +335,24 @@ struct Min
             a = b;
     }
 
+    __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ > b_)
+            a = b;
+    }
+
+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ > b_)
+            a = b;
+    }
+
     template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
     {

@@ -270,6 +379,30 @@ struct Min
             changed = true;
         }
     }
 
+    __host__ __device__ inline constexpr void operator()(half_t& a, half_t b, bool& changed) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ > b_)
+        {
+            a       = b;
+            changed = true;
+        }
+    }
+
+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ > b_)
+        {
+            a       = b;
+            changed = true;
+        }
+    }
 };
 
 struct AMax

@@ -299,6 +432,15 @@ struct AMax
             a = b;
     }
 
+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ < b_)
+            a = b;
+    }
+
     template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
     {

@@ -313,6 +455,18 @@ struct AMax
             changed = true;
         }
     }
 
+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ < b_)
+        {
+            a       = b;
+            changed = true;
+        }
+    }
 };
 
 template <typename T>

@@ -352,7 +506,8 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set,
     static constexpr bool value =
         is_same<DataType, float>::value || is_same<DataType, double>::value ||
         is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
-        is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
+        is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value ||
+        is_same<DataType, f8_t>::value;
 };
 
 template <typename DataType>
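The new narrow-type overloads all share one shape: compare (or accumulate) in float, then write back the narrow value and, for the three-argument forms, report whether the accumulator moved. A runnable stand-in for that shape, using ordinary integer types in place of half_t/f8_t:

    #include <cassert>

    // Stand-in for the new operator()(T& a, T b, bool& changed) overloads:
    // compare in a wide type, write back, and report whether 'a' changed.
    template <typename Wide, typename Narrow>
    void max_update(Narrow& a, Narrow b, bool& changed)
    {
        if(static_cast<Wide>(a) < static_cast<Wide>(b))
        {
            a       = b;
            changed = true;
        }
    }

    int main()
    {
        short a      = 3;
        bool changed = false;
        max_update<float>(a, short{7}, changed);
        assert(a == 7 && changed);
        return 0;
    }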
include/ck_tile/host/reference/reference_gemm.hpp

@@ -5,6 +5,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/host_tensor.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"
 #include <thread>
 
 namespace ck_tile {

@@ -13,6 +14,9 @@ template <typename ADataType,
           typename BDataType,
           typename AccDataType,
           typename CDataType,
+          typename LayoutA,
+          typename LayoutB,
+          typename LayoutC,
           typename AElementOp   = ck_tile::identity,
           typename BElementOp   = ck_tile::identity,
           typename ACCElementOp = ck_tile::identity>

@@ -24,7 +28,12 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                  const ACCElementOp& acc_element_op = {})
 {
     const int N = b_n_k.mDesc.get_lengths()[0];
-    const int K = b_n_k.mDesc.get_lengths()[1];
+    const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
+                      ? a_m_k.mDesc.get_lengths()[1]
+                      : a_m_k.mDesc.get_lengths()[0];
+    const int M = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
+                      ? a_m_k.mDesc.get_lengths()[0]
+                      : a_m_k.mDesc.get_lengths()[1];
 
     auto f = [&](auto m) {
         for(int n = 0; n < N; ++n)

@@ -33,7 +42,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
             for(int k = 0; k < K; ++k)
             {
-                ADataType v_a = a_element_op(a_m_k(m, k));
+                ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
+                                    ? a_element_op(a_m_k(m, k))
+                                    : a_element_op(a_m_k(k, m));
                 BDataType v_b = b_element_op(b_n_k(n, k));
 
                 v_acc += ck_tile::type_convert<AccDataType>(v_a) *

@@ -44,7 +55,123 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
         }
     };
 
-    make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
+    make_ParallelTensorFunctor(f,
+                               c_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
 }
 
+template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
+__global__ void naive_gemm_kernel(ADataType* A,
+                                  BDataType* B,
+                                  CDataType* C,
+                                  ck_tile::index_t M,
+                                  ck_tile::index_t N,
+                                  ck_tile::index_t K,
+                                  ck_tile::index_t strideA,
+                                  ck_tile::index_t strideB,
+                                  ck_tile::index_t strideC)
+{
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int row = idx / N; // Compute row index
+    int col = idx % N; // Compute column index
+
+    if(row < M && col < N)
+    {
+        AccDataType acc = 0.0;
+        for(int k = 0; k < K; ++k)
+        {
+            acc += static_cast<AccDataType>(A[row * strideA + k]) *
+                   static_cast<AccDataType>(B[col * strideB + k]);
+        }
+        C[row * strideC + col] = acc; // Store as AccDataType
+    }
+}
+
+template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
+void reference_gemm_gpu(DeviceMem& a_device,
+                        DeviceMem& b_device,
+                        DeviceMem& c_device,
+                        index_t M,
+                        index_t N,
+                        index_t K,
+                        index_t stride_a,
+                        index_t stride_b,
+                        index_t stride_c)
+{
+    ADataType* d_A;
+    BDataType* d_B;
+    CDataType* d_C;
+
+    hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType));
+    hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType));
+    hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType));
+    if(errA != hipSuccess)
+    {
+        std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
+                  << std::endl;
+        return; // Early exit on error
+    }
+    if(errB != hipSuccess)
+    {
+        std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
+                  << std::endl;
+        return; // Early exit on error
+    }
+    if(errC != hipSuccess)
+    {
+        std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
+                  << std::endl;
+        return; // Early exit on error
+    }
+
+    errA = hipMemcpy(d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType),
+                     hipMemcpyHostToDevice);
+    if(errA != hipSuccess)
+    {
+        std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
+    }
+    errB = hipMemcpy(d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType),
+                     hipMemcpyHostToDevice);
+    if(errB != hipSuccess)
+    {
+        std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
+    }
+
+    int totalElements      = M * N;
+    int numThreadsPerBlock = 256; // Common choice for threads per block
+    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
+
+    naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType>
+        <<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
+
+    errC = hipMemcpy(c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType),
+                     hipMemcpyDeviceToHost);
+    if(errC != hipSuccess)
+    {
+        std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
+    }
+
+    errA = hipFree(d_A);
+    if(errA != hipSuccess)
+    {
+        std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
+    }
+    errB = hipFree(d_B);
+    if(errB != hipSuccess)
+    {
+        std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
+    }
+    errC = hipFree(d_C);
+    if(errC != hipSuccess)
+    {
+        std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
+    }
+    return;
+}
 
} // namespace ck_tile
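A hedged usage sketch of the new GPU reference path (names from the diff; the DeviceMem buffers, sizes, and data-type choices here are assumptions standing in for whatever the caller's example sets up):

    // Illustrative only: row-major M x K A, N x K B, M x N C, fp16 in / fp32 out.
    using ADataType   = ck_tile::half_t;
    using BDataType   = ck_tile::half_t;
    using AccDataType = float;
    using CDataType   = float;

    const ck_tile::index_t M = 1024, N = 1024, K = 64;

    // a_device/b_device/c_device are DeviceMem buffers already holding the data.
    ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
        a_device, b_device, c_device, M, N, K,
        /*stride_a=*/K, /*stride_b=*/K, /*stride_c=*/N);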
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp

@@ -25,14 +25,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
     {
         using BlockGemmProblem =
             BlockGemmPipelineProblem<typename Problem::QDataType,
                                      typename Problem::KDataType,
                                      typename Problem::AccDataType,
                                      Problem::kBlockSize,
-                                     TileGemmShape<Problem::BlockFmhaShape::kM0,
-                                                   Problem::BlockFmhaShape::kN0,
-                                                   Problem::BlockFmhaShape::kK0>>;
+                                     TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                            Problem::BlockFmhaShape::kN0,
+                                                            Problem::BlockFmhaShape::kK0>,
+                                                   typename Problem::BlockFmhaShape::Gemm0BlockWarps,
+                                                   typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
 
         using WarpGemm = WarpGemmMfmaDispatcher<
             typename Problem::QDataType,

@@ -57,14 +58,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
     {
         using BlockGemmProblem =
             BlockGemmPipelineProblem<typename Problem::GemmDataType,
                                      typename Problem::OGradDataType,
                                      typename Problem::AccDataType,
                                      Problem::kBlockSize,
-                                     TileGemmShape<Problem::BlockFmhaShape::kN0,
-                                                   Problem::BlockFmhaShape::kVHeaddim,
-                                                   Problem::BlockFmhaShape::kK1>>;
+                                     TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
+                                                            Problem::BlockFmhaShape::kVHeaddim,
+                                                            Problem::BlockFmhaShape::kK1>,
+                                                   typename Problem::BlockFmhaShape::Gemm1BlockWarps,
+                                                   typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
 
         using WarpGemm =
             WarpGemmMfmaDispatcher<typename Problem::GemmDataType,

@@ -88,14 +90,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
     {
         using BlockGemmProblem =
             BlockGemmPipelineProblem<typename Problem::OGradDataType,
                                      typename Problem::VDataType,
                                      typename Problem::AccDataType,
                                      Problem::kBlockSize,
-                                     TileGemmShape<Problem::BlockFmhaShape::kM0,
-                                                   Problem::BlockFmhaShape::kN0,
-                                                   Problem::BlockFmhaShape::kK2>>;
+                                     TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                            Problem::BlockFmhaShape::kN0,
+                                                            Problem::BlockFmhaShape::kK2>,
+                                                   typename Problem::BlockFmhaShape::Gemm2BlockWarps,
+                                                   typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
 
         using WarpGemm = WarpGemmMfmaDispatcher<
             typename Problem::OGradDataType,

@@ -120,14 +123,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
     {
         using BlockGemmProblem =
             BlockGemmPipelineProblem<typename Problem::GemmDataType,
                                      typename Problem::QDataType,
                                      typename Problem::AccDataType,
                                      Problem::kBlockSize,
-                                     TileGemmShape<Problem::BlockFmhaShape::kN0,
-                                                   Problem::BlockFmhaShape::kQKHeaddim,
-                                                   Problem::BlockFmhaShape::kK3>>;
+                                     TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
+                                                            Problem::BlockFmhaShape::kQKHeaddim,
+                                                            Problem::BlockFmhaShape::kK3>,
+                                                   typename Problem::BlockFmhaShape::Gemm3BlockWarps,
+                                                   typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
 
         using WarpGemm =
             WarpGemmMfmaDispatcher<typename Problem::GemmDataType,

@@ -151,14 +155,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
     {
         using BlockGemmProblem =
             BlockGemmPipelineProblem<typename Problem::GemmDataType,
                                      typename Problem::KDataType,
                                      typename Problem::AccDataType,
                                      Problem::kBlockSize,
-                                     TileGemmShape<Problem::BlockFmhaShape::kM0,
-                                                   Problem::BlockFmhaShape::kQKHeaddim,
-                                                   Problem::BlockFmhaShape::kK4>>;
+                                     TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                            Problem::BlockFmhaShape::kQKHeaddim,
+                                                            Problem::BlockFmhaShape::kK4>,
+                                                   typename Problem::BlockFmhaShape::Gemm4BlockWarps,
+                                                   typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
 
         using WarpGemm =
             WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
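The recurring change in these five hunks threads two more compile-time properties into each block-GEMM problem: the M/N/K tile extents are now wrapped in a sequence<>, joined by the per-GEMM warp layout (GemmNBlockWarps) and per-warp tile (GemmNWarpTile) taken from BlockFmhaShape. A schematic of the shape type it now expects; the concrete numbers and the meaning of each sequence slot are assumptions for illustration:

    // Illustrative shape for the Q*K GEMM of one FMHA tile:
    // 128 x 128 x 32 block tile, a 4x1 warp grid, each warp doing 32x32x16.
    using Shape = ck_tile::TileGemmShape<
        ck_tile::sequence<128, 128, 32>,  // kM0, kN0, kK0
        ck_tile::sequence<4, 1, 1>,       // Gemm0BlockWarps (assumed layout)
        ck_tile::sequence<32, 32, 16>>;   // Gemm0WarpTile   (assumed layout)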
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

@@ -215,8 +215,8 @@ struct BlockFmhaPipelineQRKSVS
         const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
 
-        // check early exit if masked and no work to do.
-        if constexpr(FmhaMask::IsMasking)
+        // check early exit if no work to do
+        if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
         {
             if(num_total_loop <= 0)
             {
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp

@@ -268,7 +268,7 @@ struct BlockFmhaPipelineQRKSVSAsync
         const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
 
-        // check early exit
+        // check early exit if no work to do
         if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
         {
             if(num_total_loop <= 0)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp

@@ -75,14 +75,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
     {
         using BlockGemmProblem =
             BlockGemmPipelineProblem<typename Problem::QDataType,
                                      typename Problem::KDataType,
                                      typename Problem::SaccDataType,
                                      Problem::kBlockSize,
-                                     TileGemmShape<Problem::BlockFmhaShape::kM0,
-                                                   Problem::BlockFmhaShape::kN0,
-                                                   Problem::BlockFmhaShape::kK0>>;
+                                     TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                            Problem::BlockFmhaShape::kN0,
+                                                            Problem::BlockFmhaShape::kK0>,
+                                                   typename Problem::BlockFmhaShape::Gemm0BlockWarps,
+                                                   typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
 
         constexpr auto warp_gemm = []() {
             if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&

@@ -198,14 +199,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
     {
         using BlockGemmProblem =
             BlockGemmPipelineProblem<typename Problem::QDataType,
                                      typename Problem::KDataType,
                                      typename Problem::SaccDataType,
                                      Problem::kBlockSize,
-                                     TileGemmShape<Problem::BlockFmhaShape::kM0,
-                                                   Problem::BlockFmhaShape::kN0,
-                                                   Problem::BlockFmhaShape::kK0>>;
+                                     TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                            Problem::BlockFmhaShape::kN0,
+                                                            Problem::BlockFmhaShape::kK0>,
+                                                   typename Problem::BlockFmhaShape::Gemm0BlockWarps,
+                                                   typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
 
         constexpr auto warp_gemm = []() {
             if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&

@@ -952,14 +954,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
     {
         using BlockGemmProblem =
             BlockGemmPipelineProblem<typename Problem::PDataType,
                                      typename Problem::VDataType,
                                      typename Problem::OaccDataType,
                                      Problem::kBlockSize,
-                                     TileGemmShape<Problem::BlockFmhaShape::kM0,
-                                                   Problem::BlockFmhaShape::kN1,
-                                                   Problem::BlockFmhaShape::kK1>>;
+                                     TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                            Problem::BlockFmhaShape::kN1,
+                                                            Problem::BlockFmhaShape::kK1>,
+                                                   typename Problem::BlockFmhaShape::Gemm1BlockWarps,
+                                                   typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
 
         auto warp_gemm = [&]() {
             if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
include/ck_tile/ops/gemm.hpp

@@ -21,6 +21,8 @@
 #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
+#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
+#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
 #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
 #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
 #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp

    (Diff collapsed in the page view; contents not shown.)
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp

@@ -49,6 +49,10 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
         {
             return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
         }
+        else
+        {
+            static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
+        }
     }
 };
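One caveat worth noting on the new else branch: a plain static_assert(false, ...) inside a discarded if constexpr branch is only reliably accepted since C++23 (P2593); older compilers may reject it even when the branch is never taken, which is why such guards are often written with a dependent false instead. A sketch of that common workaround (not what the commit does):

    // Dependent-false idiom: the assert fires only if this branch is instantiated.
    template <typename...>
    inline constexpr bool always_false = false;

    template <typename T>
    constexpr auto pick()
    {
        if constexpr(sizeof(T) == 4)
        {
            return 1;
        }
        else
        {
            static_assert(always_false<T>, "Unsupported data type configuration.");
        }
    }

    static_assert(pick<int>() == 1);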
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp (new file, 0 → 100644)

    (Diff collapsed in the page view; contents not shown.)