Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
49c39b51
Commit
49c39b51
authored
Nov 05, 2024
by
carlushuang
Browse files
moe pipeline
parent
03c6448b
Changes
26
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
723 additions
and
219 deletions
+723
-219
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
...e/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
+32
-0
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+75
-55
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
+135
-21
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
...e/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+406
-104
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+29
-29
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
+46
-10
No files found.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
0 → 100644
View file @
49c39b51
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
enum
class
FusedMoeGemmWeightPermuteEnum
{
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
no_permute
=
0
,
b_nr_kr_kw_nw_kv
=
1
,
// 0,1,3,4,2,5
b_nr_kr_waveflatten
=
b_nr_kr_kw_nw_kv
,
};
template
<
bool
IsGateOnly_
,
bool
UseSmoothQuant_
,
index_t
OAtomic_
,
// 0-no atomic, 1-atomic-pk-f16/bf16, 2-atomic-f32
FusedMoeGemmWeightPermuteEnum
PermuteEnum_
=
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
;
bool
PadHiddenSize_
=
false
,
bool
PadIntermediateSize_
=
false
>
struct
FusedMoeGemmTraits
{
// Gate+Up or Gate only
static
constexpr
bool
IsGateOnly
=
IsGateOnly_
;
static
constexpr
bool
UseSmoothQuant
=
UseSmoothQuant_
;
static
constexpr
index_t
OAtomic
=
OAtomic_
;
static
constexpr
bool
PadHiddenSize
=
PadHiddenSize_
;
static
constexpr
bool
PadIntermediateSize
=
PadIntermediateSize_
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
View file @
49c39b51
...
...
@@ -10,114 +10,134 @@
namespace
ck_tile
{
// fp16
using
WarpGemmMfmaF16F16F32M32N32K8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
>>
;
using
WarpGemmMfmaF16F16F32M
16N16K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplF16F16F32M
16N16K16
>>
;
using
WarpGemmMfmaF16F16F32M
32N32K8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplF16F16F32M
32N32K8
<
WGAttrCtlEnum
::
Default_
>
>>
;
using
WarpGemmMfmaF16F16F32M
32N32K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
IterateK
<
WarpGemmAttributeMfmaImplF16F16F32M
32N32K8
,
2
>>
;
using
WarpGemmMfmaF16F16F32M
16N16K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplF16F16F32M
16N16K16
<
WGAttrCtlEnum
::
Default_
>
>>
;
using
WarpGemmMfmaF16F16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
,
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
1
>>
;
using
WarpGemmMfmaF16F16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
1
>>
;
using
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
>>
;
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
>>
;
using
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
,
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
// bf16
using
WarpGemmMfmaBf16Bf16F32M32N32K8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
1
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
1
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
,
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
// fp8
using
WarpGemmMfma_f32_32x32x16_fp8_fp8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_bf8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_fp8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_bf8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_fp8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_fp8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
<
WGAttrCtlEnum
::
Default_
>
>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_bf8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_bf8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
<
WGAttrCtlEnum
::
Default_
>
>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
<
WGAttrCtlEnum
::
Default_
>>>
;
template
<
index_t
swizzle_factor
=
2
>
using
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
fp8_t
>
,
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
fp8_t
,
WGAttrCtlEnum
::
Default_
>
,
2
,
swizzle_factor
>>
;
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
View file @
49c39b51
...
...
@@ -51,10 +51,13 @@ struct WarpGemmAtrributeMfma
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
Impl
{}(
c_vec
,
a_vec
,
b_vec
);
Impl
{}(
c_vec
,
a_vec
,
b_vec
,
bool_constant
<
post_nop_
>
{}
);
}
// c_vec = a_vec * b_vec
...
...
@@ -111,8 +114,11 @@ struct WarpGemmAtrributeMfmaIterateK
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
...
...
@@ -122,10 +128,33 @@ struct WarpGemmAtrributeMfmaIterateK
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
});
}
template
<
index_t
iKIter
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
number
<
iKIter
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
static_assert
(
iKIter
<
kKIter
);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
//});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
...
...
@@ -194,11 +223,14 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
// swap A and B
Impl
{}(
c_vec
,
b_vec
,
a_vec
);
Impl
{}(
c_vec
,
b_vec
,
a_vec
,
bool_constant
<
post_nop_
>
{}
);
}
// c_vec = a_vec * b_vec
...
...
@@ -255,12 +287,15 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
;
template
<
bool
post_nop_
=
false
>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
// swap A and B
Impl
{}(
c_vec
,
b_vec
,
a_vec
);
Impl
{}(
c_vec
,
b_vec
,
a_vec
,
bool_constant
<
post_nop_
>
{}
);
}
// c_vec = a_vec * b_vec
...
...
@@ -316,9 +351,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
;
template
<
bool
post_nop_
=
false
>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
...
...
@@ -328,10 +366,34 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
]);
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
});
}
template
<
index_t
iKIter
,
bool
post_nop_
=
false
>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
number
<
iKIter
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
static_assert
(
iKIter
<
kKIter
);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
//});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
...
...
@@ -429,8 +491,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
sequence
<
0
,
2
>>
;
#endif
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
...
...
@@ -440,10 +505,33 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
]);
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
});
}
template
<
index_t
iKIter
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
number
<
iKIter
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
static_assert
(
iKIter
<
kKIter
);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
//});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
...
...
@@ -518,8 +606,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
...
...
@@ -529,10 +620,33 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
});
}
template
<
index_t
iKIter
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
number
<
iKIter
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
static_assert
(
iKIter
<
kKIter
);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
//});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
View file @
49c39b51
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
View file @
49c39b51
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher;
// clang-format off
// fp16
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
// bf16
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
// fp8
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
;
};
// clang-format on
}
// namespace impl
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
View file @
49c39b51
...
...
@@ -31,11 +31,16 @@ struct WarpGemmImpl
using
BWarpTensor
=
static_distributed_tensor
<
BDataType
,
BWarpDstr
>
;
using
CWarpTensor
=
static_distributed_tensor
<
CDataType
,
CWarpDstr
>
;
CK_TILE_DEVICE
void
operator
()(
CWarpTensor
&
c
,
const
AWarpTensor
&
a
,
const
BWarpTensor
&
b
)
const
template
<
typename
CTensor
,
typename
ATensor
,
typename
BTensor
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CTensor
&
c
,
const
ATensor
&
a
,
const
BTensor
&
b
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
AVec
=
ext_vector_t
<
ADataType
,
AWarpTensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BWarpTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CWarpTensor
::
get_thread_buffer_size
()
>
;
static_assert
(
detail
::
is_similiar_distributed_tensor_v
<
CTensor
,
CTensor
>
&&
detail
::
is_similiar_distributed_tensor_v
<
ATensor
,
ATensor
>
&&
detail
::
is_similiar_distributed_tensor_v
<
BTensor
,
BTensor
>
);
using
AVec
=
ext_vector_t
<
ADataType
,
ATensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CTensor
::
get_thread_buffer_size
()
>
;
constexpr
auto
I0
=
number
<
0
>
{};
...
...
@@ -44,18 +49,49 @@ struct WarpGemmImpl
auto
c_vec
=
c
.
get_thread_buffer
().
template
get_as
<
CVec
>()[
I0
];
// c_vec += a_vec * b_vec
WarpGemmAttribute
{}(
c_vec
,
a_vec
,
b_vec
);
WarpGemmAttribute
{}(
c_vec
,
a_vec
,
b_vec
,
bool_constant
<
post_nop_
>
{}
);
c
.
get_thread_buffer
().
template
set_as
<
CVec
>(
I0
,
c_vec
);
}
CK_TILE_DEVICE
auto
operator
()(
const
AWarpTensor
&
a
,
const
BWarpTensor
&
b
)
const
template
<
typename
CTensor
,
typename
ATensor
,
typename
BTensor
,
index_t
i_subk
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CTensor
&
c
,
const
ATensor
&
a
,
const
BTensor
&
b
,
number
<
i_subk
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
CWarpTensor
c
;
using
AVec
=
ext_vector_t
<
ADataType
,
ATensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CTensor
::
get_thread_buffer_size
()
>
;
using
AVec
=
ext_vector_t
<
ADataType
,
AWarpTensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BWarpTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CWarpTensor
::
get_thread_buffer_size
()
>
;
constexpr
auto
I0
=
number
<
0
>
{};
const
auto
a_vec
=
a
.
get_thread_buffer
().
template
get_as
<
AVec
>()[
I0
];
const
auto
b_vec
=
b
.
get_thread_buffer
().
template
get_as
<
BVec
>()[
I0
];
auto
c_vec
=
c
.
get_thread_buffer
().
template
get_as
<
CVec
>()[
I0
];
// c_vec += a_vec * b_vec
WarpGemmAttribute
{}(
c_vec
,
a_vec
,
b_vec
,
number
<
i_subk
>
{},
bool_constant
<
post_nop_
>
{});
c
.
get_thread_buffer
().
template
set_as
<
CVec
>(
I0
,
c_vec
);
}
template
<
typename
ATensor
,
typename
BTensor
>
CK_TILE_DEVICE
auto
operator
()(
const
ATensor
&
a
,
const
BTensor
&
b
)
const
{
using
CTensor
=
CWarpTensor
;
static_assert
(
detail
::
is_similiar_distributed_tensor_v
<
ATensor
,
ATensor
>
&&
detail
::
is_similiar_distributed_tensor_v
<
BTensor
,
BTensor
>
);
CTensor
c
;
using
AVec
=
ext_vector_t
<
ADataType
,
ATensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CTensor
::
get_thread_buffer_size
()
>
;
constexpr
auto
I0
=
number
<
0
>
{};
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment