Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
06b891c5
Unverified
Commit
06b891c5
authored
May 20, 2024
by
Illia Silin
Committed by
GitHub
May 20, 2024
Browse files
aggregate device macros in ck_tile config header (#1297)
parent
1274861a
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
56 additions
and
56 deletions
+56
-56
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
...evice/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
...evice/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
+1
-2
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
.../tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
+2
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+4
-6
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp
...tion/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp
+4
-6
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp
.../gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp
+2
-3
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+21
-9
include/ck_tile/core/numeric/float8.hpp
include/ck_tile/core/numeric/float8.hpp
+10
-10
include/ck_tile/core/tensor/tile_elementwise.hpp
include/ck_tile/core/tensor/tile_elementwise.hpp
+1
-1
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
...e/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+10
-14
No files found.
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
06b891c5
...
...
@@ -53,8 +53,7 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
View file @
06b891c5
...
...
@@ -45,8 +45,7 @@ __global__ void
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
KBatch
=
1
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
View file @
06b891c5
...
...
@@ -50,8 +50,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
__shared__
char
p_shared
[
GridwiseGemm
::
SharedMemTrait
::
lds_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
@@ -80,7 +79,7 @@ __global__ void
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
block_2_ctile_map
;
#endif // end of if (defined(__gfx11
00
__))
#endif // end of if (defined(__gfx11__))
}
// Assume B is Col-Major
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
06b891c5
...
...
@@ -34,8 +34,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3
(
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
);
...
...
@@ -48,7 +47,7 @@ __global__ void
karg
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx9
08__) || defined(__gfx90a
__))
#endif // end of if (defined(__gfx9__))
}
template
<
typename
GridwiseGemm
,
...
...
@@ -63,8 +62,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds
(
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
__shared__
char
p_shared_0
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -81,7 +79,7 @@ __global__ void
karg
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx9
08__) || defined(__gfx90a
__))
#endif // end of if (defined(__gfx9__))
}
template
<
typename
ALayout
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp
View file @
06b891c5
...
...
@@ -33,8 +33,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3
(
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
...
...
@@ -49,7 +48,7 @@ __global__ void
karg
.
c_element_op
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx9
08__) || defined(__gfx90a
__))
#endif // end of if (defined(__gfx9__))
}
template
<
typename
GridwiseGemm
,
...
...
@@ -64,8 +63,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds
(
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
__shared__
char
p_shared_0
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -84,7 +82,7 @@ __global__ void
karg
.
c_element_op
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx9
08__) || defined(__gfx90a
__))
#endif // end of if (defined(__gfx9__))
}
template
<
typename
ALayout
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp
View file @
06b891c5
...
...
@@ -38,8 +38,7 @@ __global__ void
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
constexpr
index_t
shared_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
();
__shared__
uint8_t
p_shared
[
shared_size
];
...
...
@@ -52,7 +51,7 @@ __global__ void
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
#endif // end of if (defined(__gfx9
08__) || defined(__gfx90a
__))
#endif // end of if (defined(__gfx9__))
}
template
<
index_t
BlockSize
,
...
...
include/ck_tile/core/config.hpp
View file @
06b891c5
...
...
@@ -3,6 +3,21 @@
#pragma once
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
#define __gfx9__
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
#define __gfx103__
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#endif
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
...
...
@@ -109,15 +124,13 @@
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
#elif defined(__gfx9__) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif
#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)) // for GPU code
#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
#else
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
...
...
@@ -137,13 +150,12 @@
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || \
defined(__gfx9__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103
0
__) // for GPU code
#elif defined(__gfx103__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx11
00__) || defined(__gfx1101__) || defined(__gfx1102
__) // for GPU code
#elif defined(__gfx11__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
...
...
include/ck_tile/core/numeric/float8.hpp
View file @
06b891c5
...
...
@@ -55,7 +55,7 @@ struct alignas(1) float8_e4m3_t
{
static
constexpr
int
exponent
=
4
;
static
constexpr
int
mantissa
=
3
;
#if defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx94__)
static
constexpr
int
bias
=
1
<<
(
exponent
-
1
);
// NANOO
#else
static
constexpr
int
bias
=
(
1
<<
(
exponent
-
1
))
-
1
;
// IEEE
...
...
@@ -113,7 +113,7 @@ struct alignas(1) float8_e5m2_t
{
static
constexpr
int
exponent
=
5
;
static
constexpr
int
mantissa
=
2
;
#if defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx94__)
static
constexpr
int
bias
=
1
<<
(
exponent
-
1
);
// NANOO
#else
static
constexpr
int
bias
=
(
1
<<
(
exponent
-
1
))
-
1
;
// IEEE
...
...
@@ -470,7 +470,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x)
{
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator_t
<
float
,
seed
>
{}(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx94__)
float
max_fp8
=
240.0
f
;
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
union
...
...
@@ -500,7 +500,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
{
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator_t
<
float
,
seed
>
{}(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx94__)
union
{
float
fval
;
...
...
@@ -526,7 +526,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
CK_TILE_HOST_DEVICE
fp8_raw_t
float_to_fp8_rtn_raw
(
float
x
)
{
#if defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx94__)
float
max_fp8
=
240.0
f
;
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
union
...
...
@@ -554,7 +554,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x)
}
CK_TILE_HOST_DEVICE
bf8_raw_t
float_to_bf8_rtn_raw
(
float
x
)
{
#if defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx94__)
union
{
float
fval
;
...
...
@@ -598,7 +598,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
CK_TILE_HOST_DEVICE
float
fp8_to_float_raw
(
fp8_raw_t
x
)
{
#if defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx94__)
float
fval
;
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
x
);
fval
=
__builtin_amdgcn_cvt_f32_fp8
(
i32val
,
0
);
...
...
@@ -612,7 +612,7 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
CK_TILE_HOST_DEVICE
float
bf8_to_float_raw
(
bf8_raw_t
x
)
{
#if defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx94__)
float
fval
;
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
x
);
fval
=
__builtin_amdgcn_cvt_f32_bf8
(
i32val
,
0
);
...
...
@@ -656,7 +656,7 @@ struct numeric_traits<fp8_t>
{
static
constexpr
int
exp
=
4
;
static
constexpr
int
mant
=
3
;
#if defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx94__)
static
constexpr
int
bias
=
8
;
#else
static
constexpr
int
bias
=
7
;
...
...
@@ -668,7 +668,7 @@ struct numeric_traits<bf8_t>
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
#if defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx94__)
static
constexpr
int
bias
=
16
;
#else
static
constexpr
int
bias
=
15
;
// IEEE
...
...
include/ck_tile/core/tensor/tile_elementwise.hpp
View file @
06b891c5
...
...
@@ -112,7 +112,7 @@ namespace impl {
template
<
typename
OutDataType
,
typename
InTensor
>
CK_TILE_DEVICE
auto
cast_tile_pk_fp8x4
(
const
InTensor
&
in_dstr_tensors
)
{
#if defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx94__)
// This API is designed to use the _pk_ serious of function
constexpr
auto
in_tile_dstr
=
InTensor
::
get_tile_distribution
();
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
View file @
06b891c5
...
...
@@ -36,8 +36,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x8f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ck_tile
::
ignore
=
c_vec
;
...
...
@@ -49,8 +48,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
#if defined(__gfx9__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_32x32x8f16
(
a_vec
,
b_vec
,
fp32x16_t
{
0.
f
},
0
,
0
,
0
));
#else
...
...
@@ -89,8 +87,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ck_tile
::
ignore
=
c_vec
;
...
...
@@ -102,8 +99,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
#if defined(__gfx9__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_16x16x16f16
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
...
...
@@ -143,7 +139,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx90a__) || defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__)
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
k
)
{
...
...
@@ -167,7 +163,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx90a__) || defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx90a__) || defined(__gfx94__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k
(
a_vec
,
b_vec
,
fp32x16_t
{
0.
f
},
0
,
0
,
0
));
#elif defined(__gfx908__)
...
...
@@ -220,7 +216,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx90a__) || defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__)
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
k
)
{
...
...
@@ -244,7 +240,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx90a__) || defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx90a__) || defined(__gfx94__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#elif defined(__gfx908__)
...
...
@@ -299,7 +295,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx94__)
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
...
...
@@ -333,7 +329,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__)
#if defined(__gfx94__)
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
CVecType
{
0.
f
},
0
,
0
,
0
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment