Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
49c39b51
Commit
49c39b51
authored
Nov 05, 2024
by
carlushuang
Browse files
moe pipeline
parent
03c6448b
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
723 additions
and
219 deletions
+723
-219
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
...e/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
+32
-0
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+75
-55
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
+135
-21
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
...e/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+406
-104
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+29
-29
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
+46
-10
No files found.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
0 → 100644
View file @
49c39b51
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
enum
class
FusedMoeGemmWeightPermuteEnum
{
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
no_permute
=
0
,
b_nr_kr_kw_nw_kv
=
1
,
// 0,1,3,4,2,5
b_nr_kr_waveflatten
=
b_nr_kr_kw_nw_kv
,
};
template
<
bool
IsGateOnly_
,
bool
UseSmoothQuant_
,
index_t
OAtomic_
,
// 0-no atomic, 1-atomic-pk-f16/bf16, 2-atomic-f32
FusedMoeGemmWeightPermuteEnum
PermuteEnum_
=
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
;
bool
PadHiddenSize_
=
false
,
bool
PadIntermediateSize_
=
false
>
struct
FusedMoeGemmTraits
{
// Gate+Up or Gate only
static
constexpr
bool
IsGateOnly
=
IsGateOnly_
;
static
constexpr
bool
UseSmoothQuant
=
UseSmoothQuant_
;
static
constexpr
index_t
OAtomic
=
OAtomic_
;
static
constexpr
bool
PadHiddenSize
=
PadHiddenSize_
;
static
constexpr
bool
PadIntermediateSize
=
PadIntermediateSize_
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
View file @
49c39b51
...
@@ -10,114 +10,134 @@
...
@@ -10,114 +10,134 @@
namespace
ck_tile
{
namespace
ck_tile
{
// fp16
// fp16
using
WarpGemmMfmaF16F16F32M32N32K8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
>>
;
using
WarpGemmMfmaF16F16F32M
16N16K16
=
using
WarpGemmMfmaF16F16F32M
32N32K8
=
WarpGemmImpl
<
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplF16F16F32M
16N16K16
>>
;
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplF16F16F32M
32N32K8
<
WGAttrCtlEnum
::
Default_
>
>>
;
using
WarpGemmMfmaF16F16F32M
32N32K16
=
using
WarpGemmMfmaF16F16F32M
16N16K16
=
WarpGemmImpl
<
WarpGemmImpl
<
WarpGemmAtrributeMfma
IterateK
<
WarpGemmAttributeMfmaImplF16F16F32M
32N32K8
,
2
>>
;
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplF16F16F32M
16N16K16
<
WGAttrCtlEnum
::
Default_
>
>>
;
using
WarpGemmMfmaF16F16F32M16N16K32
=
using
WarpGemmMfmaF16F16F32M32N32K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
,
2
>>
;
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
using
WarpGemmMfmaF16F16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
1
>>
;
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
using
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
2
>>
;
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
1
>>
;
using
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
=
WarpGemmImpl
<
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
>>
;
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
=
WarpGemmImpl
<
using
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
=
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
=
using
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
2
>>
;
using
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
=
using
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
,
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution
=
using
WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
2
>>
;
// bf16
// bf16
using
WarpGemmMfmaBf16Bf16F32M32N32K8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K16
=
using
WarpGemmMfmaBf16Bf16F32M32N32K8
=
WarpGemmImpl
<
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
>>
;
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16
=
using
WarpGemmMfmaBf16Bf16F32M32N32K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
2
>>
;
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K32
=
using
WarpGemmMfmaBf16Bf16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
,
2
>>
;
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
using
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
1
>>
;
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
1
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
=
WarpGemmImpl
<
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
=
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
2
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK_SwizzleA
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
=
WarpGemmImpl
<
using
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
=
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
=
WarpGemmImpl
<
using
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
=
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
=
using
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
=
using
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
,
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution
=
using
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
2
>>
;
// fp8
// fp8
using
WarpGemmMfma_f32_32x32x16_fp8_fp8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_bf8
=
using
WarpGemmMfma_f32_32x32x16_fp8_fp8
=
WarpGemmImpl
<
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
>>
;
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_bf8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_fp8
=
using
WarpGemmMfma_f32_32x32x16_bf8_fp8
=
WarpGemmImpl
<
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
>>
;
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
<
WGAttrCtlEnum
::
Default_
>
>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_bf8
=
using
WarpGemmMfma_f32_32x32x16_bf8_bf8
=
WarpGemmImpl
<
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
>>
;
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
<
WGAttrCtlEnum
::
Default_
>
>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
=
WarpGemmImpl
<
using
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
=
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
=
WarpGemmImpl
<
using
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
=
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
=
WarpGemmImpl
<
using
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
=
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
<
WGAttrCtlEnum
::
Default_
>>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
=
WarpGemmImpl
<
using
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
=
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
>>
;
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
<
WGAttrCtlEnum
::
Default_
>>>
;
template
<
index_t
swizzle_factor
=
2
>
template
<
index_t
swizzle_factor
=
2
>
using
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
=
using
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
fp8_t
>
,
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
fp8_t
,
WGAttrCtlEnum
::
Default_
>
,
2
,
2
,
swizzle_factor
>>
;
swizzle_factor
>>
;
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
View file @
49c39b51
...
@@ -51,10 +51,13 @@ struct WarpGemmAtrributeMfma
...
@@ -51,10 +51,13 @@ struct WarpGemmAtrributeMfma
sequence
<
0
,
2
>>
;
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
Impl
{}(
c_vec
,
a_vec
,
b_vec
);
Impl
{}(
c_vec
,
a_vec
,
b_vec
,
bool_constant
<
post_nop_
>
{}
);
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -111,8 +114,11 @@ struct WarpGemmAtrributeMfmaIterateK
...
@@ -111,8 +114,11 @@ struct WarpGemmAtrributeMfmaIterateK
sequence
<
0
,
2
>>
;
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
...
@@ -122,10 +128,33 @@ struct WarpGemmAtrributeMfmaIterateK
...
@@ -122,10 +128,33 @@ struct WarpGemmAtrributeMfmaIterateK
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
});
});
}
}
template
<
index_t
iKIter
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
number
<
iKIter
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
static_assert
(
iKIter
<
kKIter
);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
//});
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
{
...
@@ -194,11 +223,14 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
...
@@ -194,11 +223,14 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
sequence
<
0
,
2
>>
;
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
// swap A and B
// swap A and B
Impl
{}(
c_vec
,
b_vec
,
a_vec
);
Impl
{}(
c_vec
,
b_vec
,
a_vec
,
bool_constant
<
post_nop_
>
{}
);
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -255,12 +287,15 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
...
@@ -255,12 +287,15 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
sequence
<
2
,
2
>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
;
sequence
<
0
,
2
>>
;
template
<
bool
post_nop_
=
false
>
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
// swap A and B
// swap A and B
Impl
{}(
c_vec
,
b_vec
,
a_vec
);
Impl
{}(
c_vec
,
b_vec
,
a_vec
,
bool_constant
<
post_nop_
>
{}
);
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -316,9 +351,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
...
@@ -316,9 +351,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
sequence
<
2
,
2
>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
;
sequence
<
0
,
2
>>
;
template
<
bool
post_nop_
=
false
>
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
...
@@ -328,10 +366,34 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
...
@@ -328,10 +366,34 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
]);
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
});
});
}
}
template
<
index_t
iKIter
,
bool
post_nop_
=
false
>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
number
<
iKIter
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
static_assert
(
iKIter
<
kKIter
);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
//});
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
{
...
@@ -429,8 +491,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
...
@@ -429,8 +491,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
sequence
<
0
,
2
>>
;
sequence
<
0
,
2
>>
;
#endif
#endif
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
...
@@ -440,10 +505,33 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
...
@@ -440,10 +505,33 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
]);
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
});
});
}
}
template
<
index_t
iKIter
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
number
<
iKIter
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
static_assert
(
iKIter
<
kKIter
);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
//});
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
{
...
@@ -518,8 +606,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
...
@@ -518,8 +606,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
sequence
<
0
,
2
>>
;
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
...
@@ -529,10 +620,33 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
...
@@ -529,10 +620,33 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
});
});
}
}
template
<
index_t
iKIter
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
number
<
iKIter
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
static_assert
(
iKIter
<
kKIter
);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
bool_constant
<
post_nop_
>
{});
//});
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
{
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
View file @
49c39b51
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -7,12 +7,45 @@
...
@@ -7,12 +7,45 @@
namespace
ck_tile
{
namespace
ck_tile
{
// TODO: refactor warp-gemm
// currently there is a discrepency for vav/vva if we need transpose C/D
// e.g. if we want A:agpr, B:vgpr, we have to use vva in WGAttrEnum
// because we swap the A/B pointer in _impl code (but not known this info here)
enum
class
WGAttrCtlEnum
{
Default_
=
0
,
Raw_vvv
=
1
,
// c-vgpr, a-vgpr, b-vgpr
Raw_vaa
=
2
,
// c-vgpr, a-agpr, b-agpr
Raw_vav
=
3
,
// c-vgpr, a-agpr, b-vgpr
Raw_vva
=
4
,
// c-vgpr, a-vgpr, b-agpr
// raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
};
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
if constexpr(post_nop_) \
{ \
asm volatile(mfma_ " %0, %1, %2, %3\n" \
"s_nop 3" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
} \
else \
{ \
asm volatile(mfma_ " %0, %1, %2, %3\n" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
}
// FP16
// FP16
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
struct
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
{
{
using
ADataType
=
fp16_t
;
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
BDataType
=
fp16_t
;
using
ADataType
=
fp16_t
;
using
CDataType
=
float
;
using
BDataType
=
fp16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
...
@@ -33,16 +66,38 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
...
@@ -33,16 +66,38 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
static
constexpr
index_t
kCM1PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vvv
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_32x32x8f16"
,
"+v"
,
"v"
,
"v"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vaa
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_32x32x8f16"
,
"+v"
,
"a"
,
"a"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vav
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_32x32x8f16"
,
"+v"
,
"a"
,
"v"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vva
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_32x32x8f16"
,
"+v"
,
"v"
,
"a"
,
"v"
)
}
else
{
#if defined(__gfx9__)
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x8f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x8f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
#else
ignore
=
c_vec
;
ck_tile
::
ignore
=
c_vec
;
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
#endif
}
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -52,18 +107,20 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
...
@@ -52,18 +107,20 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
return
bit_cast
<
CVecType
>
(
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_32x32x8f16
(
a_vec
,
b_vec
,
fp32x16_t
{
0.
f
},
0
,
0
,
0
));
__builtin_amdgcn_mfma_f32_32x32x8f16
(
a_vec
,
b_vec
,
fp32x16_t
{
0.
f
},
0
,
0
,
0
));
#else
#else
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
return
CVecType
{
0.
f
};
#endif
#endif
}
}
};
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
struct
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
{
{
using
ADataType
=
fp16_t
;
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
BDataType
=
fp16_t
;
using
ADataType
=
fp16_t
;
using
CDataType
=
float
;
using
BDataType
=
fp16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
...
@@ -84,16 +141,38 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
...
@@ -84,16 +141,38 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
static
constexpr
index_t
kCM1PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vvv
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_16x16x16f16"
,
"+v"
,
"v"
,
"v"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vaa
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_16x16x16f16"
,
"+v"
,
"a"
,
"a"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vav
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_16x16x16f16"
,
"+v"
,
"a"
,
"v"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vva
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_16x16x16f16"
,
"+v"
,
"v"
,
"a"
,
"v"
)
}
else
{
#if defined(__gfx9__)
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
#else
ignore
=
c_vec
;
ck_tile
::
ignore
=
c_vec
;
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
#endif
}
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -103,19 +182,21 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
...
@@ -103,19 +182,21 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
return
bit_cast
<
CVecType
>
(
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_16x16x16f16
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
__builtin_amdgcn_mfma_f32_16x16x16f16
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
#else
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
return
CVecType
{
0.
f
};
#endif
#endif
}
}
};
};
// Bf16
// Bf16
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
{
{
using
ADataType
=
bf16_t
;
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
BDataType
=
bf16_t
;
using
ADataType
=
bf16_t
;
using
CDataType
=
float
;
using
BDataType
=
bf16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
...
@@ -136,28 +217,50 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
...
@@ -136,28 +217,50 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
static
constexpr
index_t
kCM1PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vvv
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_32x32x8bf16_1k"
,
"+v"
,
"v"
,
"v"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vaa
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_32x32x8bf16_1k"
,
"+v"
,
"a"
,
"a"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vav
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_32x32x8bf16_1k"
,
"+v"
,
"a"
,
"v"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vva
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_32x32x8bf16_1k"
,
"+v"
,
"v"
,
"a"
,
"v"
)
}
else
{
#if defined(__gfx90a__) || defined(__gfx94__)
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__)
#elif defined(__gfx908__)
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
k
)
{
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x4bf16
(
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x4bf16
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
4
>&>
(
a_vec
)
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
4
>&>
(
a_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
4
>&>
(
b_vec
)
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
4
>&>
(
b_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
c_vec
,
c_vec
,
0
,
0
,
0
,
0
,
0
);
0
);
});
});
#else
#else
ignore
=
c_vec
;
ck_tile
::
ignore
=
c_vec
;
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
#endif
}
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -181,18 +284,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
...
@@ -181,18 +284,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
});
});
return
c_vec
;
return
c_vec
;
#else
#else
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
return
CVecType
{
0.
f
};
#endif
#endif
}
}
};
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
{
{
using
ADataType
=
bf16_t
;
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
BDataType
=
bf16_t
;
using
ADataType
=
bf16_t
;
using
CDataType
=
float
;
using
BDataType
=
bf16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
...
@@ -213,28 +318,50 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
...
@@ -213,28 +318,50 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
static
constexpr
index_t
kCM1PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vvv
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_16x16x16bf16_1k"
,
"+v"
,
"v"
,
"v"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vaa
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_16x16x16bf16_1k"
,
"+v"
,
"a"
,
"a"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vav
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_16x16x16bf16_1k"
,
"+v"
,
"a"
,
"v"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vva
)
{
DISPATCH_MFMA_
(
"v_mfma_f32_16x16x16bf16_1k"
,
"+v"
,
"v"
,
"a"
,
"v"
)
}
else
{
#if defined(__gfx90a__) || defined(__gfx94__)
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__)
#elif defined(__gfx908__)
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
k
)
{
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x8bf16
(
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x8bf16
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
4
>&>
(
a_vec
)
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
4
>&>
(
a_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
4
>&>
(
b_vec
)
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
4
>&>
(
b_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
c_vec
,
c_vec
,
0
,
0
,
0
,
0
,
0
);
0
);
});
});
#else
#else
ignore
=
c_vec
;
ck_tile
::
ignore
=
c_vec
;
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
#endif
}
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -258,20 +385,21 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
...
@@ -258,20 +385,21 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
});
});
return
c_vec
;
return
c_vec
;
#else
#else
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
return
CVecType
{
0.
f
};
#endif
#endif
}
}
};
};
// FP8
// FP8
template
<
typename
AType_
,
typename
BType_
>
template
<
typename
AType_
,
typename
BType_
,
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
struct
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
{
{
using
ADataType
=
AType_
;
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
BDataType
=
BType_
;
using
ADataType
=
AType_
;
using
CDataType
=
float
;
using
BDataType
=
BType_
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
ADataType
,
8
>
;
using
AVecType
=
ext_vector_t
<
ADataType
,
8
>
;
using
BVecType
=
ext_vector_t
<
BDataType
,
8
>
;
using
BVecType
=
ext_vector_t
<
BDataType
,
8
>
;
...
@@ -292,38 +420,120 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
...
@@ -292,38 +420,120 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
static
constexpr
index_t
kCM1PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
template
<
bool
post_nop_
=
false
>
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vvv
)
{
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_fp8"
,
"+v"
,
"v"
,
"v"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_bf8"
,
"+v"
,
"v"
,
"v"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_fp8"
,
"+v"
,
"v"
,
"v"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_bf8"
,
"+v"
,
"v"
,
"v"
,
"v"
)
}
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vaa
)
{
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_fp8"
,
"+v"
,
"a"
,
"a"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_bf8"
,
"+v"
,
"a"
,
"a"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_fp8"
,
"+v"
,
"a"
,
"a"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_bf8"
,
"+v"
,
"a"
,
"a"
,
"v"
)
}
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vav
)
{
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_fp8"
,
"+v"
,
"a"
,
"v"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_bf8"
,
"+v"
,
"a"
,
"v"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_fp8"
,
"+v"
,
"a"
,
"v"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_bf8"
,
"+v"
,
"a"
,
"v"
,
"v"
)
}
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vva
)
{
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_fp8"
,
"+v"
,
"v"
,
"a"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_fp8_bf8"
,
"+v"
,
"v"
,
"a"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_fp8"
,
"+v"
,
"v"
,
"a"
,
"v"
)
}
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
{
DISPATCH_MFMA_
(
"mfma_f32_32x32x16_bf8_bf8"
,
"+v"
,
"v"
,
"a"
,
"v"
)
}
}
else
{
#if defined(__gfx94__)
#if defined(__gfx94__)
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8
(
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8
(
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8
(
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8
(
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__) || defined(__gfx90a__)
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
a_f32
=
float
a_f32
=
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
8
>&>
(
a_vec
)
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
8
>&>
(
a_vec
)
.
template
get_as
<
ADataType
>()[
number
<
k
>
{}]);
.
template
get_as
<
ADataType
>()[
number
<
k
>
{}]);
float
b_f32
=
float
b_f32
=
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
8
>&>
(
b_vec
)
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
8
>&>
(
b_vec
)
.
template
get_as
<
BDataType
>()[
number
<
k
>
{}]);
.
template
get_as
<
BDataType
>()[
number
<
k
>
{}]);
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x2f32
(
a_f32
,
b_f32
,
c_vec
,
0
,
0
,
0
);
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x2f32
(
a_f32
,
b_f32
,
c_vec
,
0
,
0
,
0
);
});
});
#else
#else
ignore
=
c_vec
;
ck_tile
::
ignore
=
c_vec
;
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
#endif
}
}
}
// c_vec = a_vec * b_vec
// c_vec = a_vec * b_vec
...
@@ -356,20 +566,112 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
...
@@ -356,20 +566,112 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
});
});
return
c_vec
;
return
c_vec
;
#else
#else
ignore
=
a_vec
;
ck_tile
::
ignore
=
a_vec
;
ignore
=
b_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
return
CVecType
{
0.
f
};
#endif
#endif
}
}
};
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
=
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
=
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
fp8_t
>
;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
fp8_t
,
Ctrl_
>
;
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
=
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
=
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
bf8_t
>
;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
bf8_t
,
Ctrl_
>
;
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
=
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
=
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
bf8_t
,
fp8_t
>
;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
bf8_t
,
fp8_t
,
Ctrl_
>
;
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
=
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
=
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
bf8_t
,
bf8_t
>
;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
bf8_t
,
bf8_t
,
Ctrl_
>
;
// int8
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int32_t
;
using
AVecType
=
ext_vector_t
<
ADataType
,
8
>
;
using
BVecType
=
ext_vector_t
<
BDataType
,
8
>
;
using
CVecType
=
ext_vector_t
<
CDataType
,
16
>
;
static
constexpr
index_t
kM
=
32
;
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
static
constexpr
index_t
kABKPerLane
=
8
;
static
constexpr
index_t
kCMLane
=
2
;
static
constexpr
index_t
kCNLane
=
32
;
static
constexpr
index_t
kCM0PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vvv
)
{
DISPATCH_MFMA_
(
"v_mfma_i32_32x32x16_i8"
,
"+v"
,
"v"
,
"v"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vaa
)
{
DISPATCH_MFMA_
(
"v_mfma_i32_32x32x16_i8"
,
"+v"
,
"a"
,
"a"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vav
)
{
DISPATCH_MFMA_
(
"v_mfma_i32_32x32x16_i8"
,
"+v"
,
"a"
,
"v"
,
"v"
)
}
else
if
constexpr
(
Ctrl
==
WGAttrCtlEnum
::
Raw_vva
)
{
DISPATCH_MFMA_
(
"v_mfma_i32_32x32x16_i8"
,
"+v"
,
"v"
,
"a"
,
"v"
)
}
else
{
#if defined(__gfx94__)
c_vec
=
__builtin_amdgcn_mfma_i32_32x32x8i8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
a_f32
=
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
8
>&>
(
a_vec
)
.
template
get_as
<
ADataType
>()[
number
<
k
>
{}]);
float
b_f32
=
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
8
>&>
(
b_vec
)
.
template
get_as
<
BDataType
>()[
number
<
k
>
{}]);
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x2f32
(
a_f32
,
b_f32
,
c_vec
,
0
,
0
,
0
);
});
#else
ck_tile
::
ignore
=
c_vec
;
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
CVecType
c_vec
{
0
};
operator
()(
c_vec
,
a_vec
,
b_vec
);
return
c_vec
;
}
};
#undef DISPATCH_MFMA_
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
View file @
49c39b51
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher;
...
@@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher;
// clang-format off
// clang-format off
// fp16
// fp16
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
// bf16
// bf16
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
// fp8
// fp8
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
;
};
// clang-format on
// clang-format on
}
// namespace impl
}
// namespace impl
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
View file @
49c39b51
...
@@ -31,11 +31,16 @@ struct WarpGemmImpl
...
@@ -31,11 +31,16 @@ struct WarpGemmImpl
using
BWarpTensor
=
static_distributed_tensor
<
BDataType
,
BWarpDstr
>
;
using
BWarpTensor
=
static_distributed_tensor
<
BDataType
,
BWarpDstr
>
;
using
CWarpTensor
=
static_distributed_tensor
<
CDataType
,
CWarpDstr
>
;
using
CWarpTensor
=
static_distributed_tensor
<
CDataType
,
CWarpDstr
>
;
CK_TILE_DEVICE
void
operator
()(
CWarpTensor
&
c
,
const
AWarpTensor
&
a
,
const
BWarpTensor
&
b
)
const
template
<
typename
CTensor
,
typename
ATensor
,
typename
BTensor
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CTensor
&
c
,
const
ATensor
&
a
,
const
BTensor
&
b
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
using
AVec
=
ext_vector_t
<
ADataType
,
AWarpTensor
::
get_thread_buffer_size
()
>
;
static_assert
(
detail
::
is_similiar_distributed_tensor_v
<
CTensor
,
CTensor
>
&&
using
BVec
=
ext_vector_t
<
BDataType
,
BWarpTensor
::
get_thread_buffer_size
()
>
;
detail
::
is_similiar_distributed_tensor_v
<
ATensor
,
ATensor
>
&&
using
CVec
=
ext_vector_t
<
CDataType
,
CWarpTensor
::
get_thread_buffer_size
()
>
;
detail
::
is_similiar_distributed_tensor_v
<
BTensor
,
BTensor
>
);
using
AVec
=
ext_vector_t
<
ADataType
,
ATensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CTensor
::
get_thread_buffer_size
()
>
;
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I0
=
number
<
0
>
{};
...
@@ -44,18 +49,49 @@ struct WarpGemmImpl
...
@@ -44,18 +49,49 @@ struct WarpGemmImpl
auto
c_vec
=
c
.
get_thread_buffer
().
template
get_as
<
CVec
>()[
I0
];
auto
c_vec
=
c
.
get_thread_buffer
().
template
get_as
<
CVec
>()[
I0
];
// c_vec += a_vec * b_vec
// c_vec += a_vec * b_vec
WarpGemmAttribute
{}(
c_vec
,
a_vec
,
b_vec
);
WarpGemmAttribute
{}(
c_vec
,
a_vec
,
b_vec
,
bool_constant
<
post_nop_
>
{}
);
c
.
get_thread_buffer
().
template
set_as
<
CVec
>(
I0
,
c_vec
);
c
.
get_thread_buffer
().
template
set_as
<
CVec
>(
I0
,
c_vec
);
}
}
CK_TILE_DEVICE
auto
operator
()(
const
AWarpTensor
&
a
,
const
BWarpTensor
&
b
)
const
template
<
typename
CTensor
,
typename
ATensor
,
typename
BTensor
,
index_t
i_subk
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CTensor
&
c
,
const
ATensor
&
a
,
const
BTensor
&
b
,
number
<
i_subk
>
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
CWarpTensor
c
;
using
AVec
=
ext_vector_t
<
ADataType
,
ATensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CTensor
::
get_thread_buffer_size
()
>
;
using
AVec
=
ext_vector_t
<
ADataType
,
AWarpTensor
::
get_thread_buffer_size
()
>
;
constexpr
auto
I0
=
number
<
0
>
{};
using
BVec
=
ext_vector_t
<
BDataType
,
BWarpTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CWarpTensor
::
get_thread_buffer_size
()
>
;
const
auto
a_vec
=
a
.
get_thread_buffer
().
template
get_as
<
AVec
>()[
I0
];
const
auto
b_vec
=
b
.
get_thread_buffer
().
template
get_as
<
BVec
>()[
I0
];
auto
c_vec
=
c
.
get_thread_buffer
().
template
get_as
<
CVec
>()[
I0
];
// c_vec += a_vec * b_vec
WarpGemmAttribute
{}(
c_vec
,
a_vec
,
b_vec
,
number
<
i_subk
>
{},
bool_constant
<
post_nop_
>
{});
c
.
get_thread_buffer
().
template
set_as
<
CVec
>(
I0
,
c_vec
);
}
template
<
typename
ATensor
,
typename
BTensor
>
CK_TILE_DEVICE
auto
operator
()(
const
ATensor
&
a
,
const
BTensor
&
b
)
const
{
using
CTensor
=
CWarpTensor
;
static_assert
(
detail
::
is_similiar_distributed_tensor_v
<
ATensor
,
ATensor
>
&&
detail
::
is_similiar_distributed_tensor_v
<
BTensor
,
BTensor
>
);
CTensor
c
;
using
AVec
=
ext_vector_t
<
ADataType
,
ATensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CTensor
::
get_thread_buffer_size
()
>
;
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I0
=
number
<
0
>
{};
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment