gaoqiong / composable_kernel_ROCM · Commits · a199815a

Unverified commit a199815a, authored Feb 06, 2025 by Muhammed Emin Ozturk; committed by GitHub on Feb 06, 2025.

Merge branch 'develop' into muozturk_sk_padding

Parents: a9a3a3e2, 2bef5501
Changes: 170
Showing 20 changed files with 1632 additions and 51 deletions (+1632, -51)
include/ck/utility/type_convert.hpp (+1340, -6)
include/ck_tile/core.hpp (+1, -0)
include/ck_tile/core/config.hpp (+4, -0)
include/ck_tile/core/numeric/pk_int4.hpp (+140, -0)
include/ck_tile/core/numeric/vector_type.hpp (+18, -1)
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc (+1, -1)
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc (+1, -1)
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc (+1, -1)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp (+36, -12)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp (+12, -4)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp (+8, -2)
library/src/tensor_operation_instance/gpu/CMakeLists.txt (+7, -7)
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp (+5, -1)
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp (+5, -1)
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp (+10, -2)
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp (+10, -2)
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp (+12, -4)
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp (+7, -2)
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp (+7, -2)
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp (+7, -2)
include/ck/utility/type_convert.hpp
...
...
@@ -5,15 +5,39 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/mxf4_utils.hpp"
#include "ck/utility/mxf6_utils.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/amd_inline_asm.hpp"
#include "ck/utility/type.hpp"
namespace
ck
{
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|| defined(__gfx950__)
#define __gfx94__
#endif
namespace
{
namespace
details
{
[[
maybe_unused
]]
__host__
half2_t
pk_add_f16
(
const
half2_t
&
x
,
const
half2_t
&
y
)
{
half2_t
vector_res
;
vector_res
.
x
=
x
.
x
+
y
.
x
;
vector_res
.
y
=
x
.
y
+
y
.
y
;
return
vector_res
;
}
[[
maybe_unused
]]
__device__
half2_t
pk_add_f16
(
const
half2_t
&
x
,
const
half2_t
&
y
)
{
return
amd_assembly_pk_add_f16
(
x
,
y
);
}
}
// namespace details
}
// namespace
// Declare a template function for bf16 conversion using RTN
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
bf16_convert_rtn
(
X
x
);
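The details::pk_add_f16 pair above gives the same packed half-precision add a plain C++ body for host code and an inline-assembly body for device code. A minimal stand-alone sketch of the semantics follows; it is an illustration only (not library code), with float standing in for half_t so it compiles anywhere:

#include <utility>

// Lane-wise add of two 2-element "packed" values - the operation the
// __host__/__device__ overload pair above selects between at compile time.
inline std::pair<float, float> pk_add(const std::pair<float, float>& a,
                                      const std::pair<float, float>& b)
{
    return {a.first + b.first, a.second + b.second};
}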
...
...
@@ -520,13 +544,51 @@ template <>
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);

-   uint8_t x_l = (x_u8 & 0x0f) >> 0;
-   uint8_t x_h = (x_u8 & 0xf0) >> 4;
-   auto l_f32 = ck::type_convert<float>(x_l);
-   auto h_f32 = ck::type_convert<float>(x_h);

    float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    float2_t res = {x_h, x_l};
#else
    float2_t res = {x_l, x_h};
#endif
    return res;
}

template <>
inline __host__ __device__ half2_t type_convert<half2_t, pk_i4_t>(pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#else
    uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif

    const int EX  = 0x64006400;
    const int SUB = 0xE408E408; //-8

    int lo = i4s | EX;

-   return {l_f32, h_f32};
    return details::pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
}

template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, pk_i4_t>(pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);

    float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    bhalf2_t res = {type_convert<bhalf_t>(x_h), type_convert<bhalf_t>(x_l)};
#else
    bhalf2_t res = {type_convert<bhalf_t>(x_l), type_convert<bhalf_t>(x_h)};
#endif
    return res;
}
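The half2_t path above relies on an fp16 bias trick: OR-ing a 4-bit nibble n into the mantissa of the fp16 constant 0x6400 (= 1024.0) produces 1024 + n, and the packed add of the 0xE408 constant (= -1032.0 in fp16) then leaves n - 8, the signed int4 value. A stand-alone host check of that arithmetic (illustration only, not library code):

#include <cassert>
#include <cstdint>

int main()
{
    const uint8_t packed = 0x2B; // high nibble 0x2, low nibble 0xB
    for(int nibble : {packed & 0x0F, (packed & 0xF0) >> 4})
    {
        float biased  = 1024.0f + static_cast<float>(nibble); // value of fp16 (0x6400 | nibble)
        float decoded = biased + (-1032.0f);                  // packed add of the 0xE408 constant
        assert(decoded == static_cast<float>(nibble) - 8.0f); // signed int4 value recovered
    }
    return 0;
}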
template <>
...
...
@@ -647,6 +709,1278 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
#endif
}
// convert fp32 to fp4 with rounding to nearest even
inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4_t f4_array[4];
    } value{0};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x, x, scale, 0);
    return value.f4_array[0];
#else
    return utils::sat_convert_to_type<f4_t>(x / scale);
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4 with rne
inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
    return value.f4x2_array[0];
#else
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    uint8_t l = utils::sat_convert_to_type<f4_t>(x[1] / scale);
    uint8_t h = utils::sat_convert_to_type<f4_t>(x[0] / scale);
    value.bitwise = (h << 4) | l;
    return value.f4x2_array[0];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with rne
inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{}, tmp_values{};
    // TODO: pack in a loop
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[0], x[1], scale, 0);
    f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[2], x[3], scale, 0);
    f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[4], x[5], scale, 0);
    f4_values.f4x2_array[2] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[6], x[7], scale, 0);
    f4_values.f4x2_array[3] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[8], x[9], scale, 0);
    f4_values.f4x2_array[4] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[10], x[11], scale, 0);
    f4_values.f4x2_array[5] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[12], x[13], scale, 0);
    f4_values.f4x2_array[6] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[14], x[15], scale, 0);
    f4_values.f4x2_array[7] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[16], x[17], scale, 0);
    f4_values.f4x2_array[8] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[18], x[19], scale, 0);
    f4_values.f4x2_array[9] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[20], x[21], scale, 0);
    f4_values.f4x2_array[10] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[22], x[23], scale, 0);
    f4_values.f4x2_array[11] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[24], x[25], scale, 0);
    f4_values.f4x2_array[12] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[26], x[27], scale, 0);
    f4_values.f4x2_array[13] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[28], x[29], scale, 0);
    f4_values.f4x2_array[14] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[30], x[31], scale, 0);
    f4_values.f4x2_array[15] = tmp_values.f4x2_array[0];
    return f4_values.f4x32_array;
#else
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{};
    // TODO: pack in a loop
    auto tmp = utils::sat_convert_to_type<f4_t>(x[0] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[1] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[2] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[3] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[4] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[5] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[6] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[7] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[8] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[9] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[10] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[11] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[12] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[13] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[14] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[15] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[16] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[17] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[18] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[19] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[20] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[21] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[22] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[23] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[24] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[25] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[26] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[27] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[28] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[29] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[30] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type<f4_t>(x[31] / scale); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    return f4_values.f4x32_array;
#endif
}
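In the non-gfx950 branch above, shifting `bitwise` left by four and OR-ing in each converted element means the first element processed ends up in the most significant nibble. A stand-alone sketch of that packing order (illustration only, using 8 nibbles in a uint32_t instead of 32 in a __uint128_t):

#include <cassert>
#include <cstdint>

int main()
{
    uint32_t packed = 0;
    for(uint32_t nibble = 0; nibble < 8; ++nibble) // pretend these are converted fp4 codes
    {
        packed <<= 4;
        packed |= nibble;
    }
    // Element 0 lands in the top nibble, element 7 in the bottom one.
    assert(packed == 0x01234567u);
    return 0;
}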
// convert fp32 to fp4 with stochastic rounding
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4_t f4_array[4];
    } value{0};
    union
    {
        float float_array[2];
        float2_t float2_array;
    } float_values{{x}};
    value.bitwise =
        __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, float_values.float2_array, rng, scale, 0);
    return value.f4_array[0];
#else
    return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4 with sr
inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0);
    return value.f4x2_array[0];
#else
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
    uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
    value.bitwise = (h << 4) | l;
    return value.f4x2_array[0];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with sr
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{0}, tmp_values{0};
    union
    {
        float2_t floatx2_array[16];
        float32_t floatx32_array;
    } float_values{{0}};
    float_values.floatx32_array = x; // feed the input vector into the union before packing
    // TODO: pack in a loop
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[0], rng, scale, 0);
    f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[1], rng, scale, 0);
    f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[2], rng, scale, 0);
    f4_values.f4x2_array[2] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[3], rng, scale, 0);
    f4_values.f4x2_array[3] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[4], rng, scale, 0);
    f4_values.f4x2_array[4] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[5], rng, scale, 0);
    f4_values.f4x2_array[5] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[6], rng, scale, 0);
    f4_values.f4x2_array[6] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[7], rng, scale, 0);
    f4_values.f4x2_array[7] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[8], rng, scale, 0);
    f4_values.f4x2_array[8] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[9], rng, scale, 0);
    f4_values.f4x2_array[9] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[10], rng, scale, 0);
    f4_values.f4x2_array[10] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[11], rng, scale, 0);
    f4_values.f4x2_array[11] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[12], rng, scale, 0);
    f4_values.f4x2_array[12] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[13], rng, scale, 0);
    f4_values.f4x2_array[13] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[14], rng, scale, 0);
    f4_values.f4x2_array[14] = tmp_values.f4x2_array[0];
    tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(tmp_values.bitwise, float_values.floatx2_array[15], rng, scale, 0);
    f4_values.f4x2_array[15] = tmp_values.f4x2_array[0];
    return f4_values.f4x32_array;
#else
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{0};
    // TODO: pack in a loop
    auto tmp = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[2] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[3] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[4] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[5] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[6] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[7] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[8] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[9] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[10] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[11] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[12] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[13] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[14] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[15] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[16] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[17] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[18] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[19] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[20] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[21] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[22] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[23] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[24] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[25] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[26] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[27] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[28] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[29] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[30] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    tmp = utils::sat_convert_to_type_sr<f4_t>(x[31] / scale, rng); f4_values.bitwise <<= 4; f4_values.bitwise |= tmp;
    return f4_values.f4x32_array;
#endif
}
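The `_sr` variants above draw a pseudo-random word from prand_generator and let the hardware path (or utils::sat_convert_to_type_sr) round up or down with probability proportional to the distance to each neighbour. A minimal stand-alone illustration of stochastic rounding to an integer grid (a sketch of the general technique, not the CK implementation):

#include <cmath>
#include <cstdint>

// Round x to floor(x) or floor(x)+1, choosing the upper neighbour with
// probability equal to the fractional part; rng is a uniform 32-bit sample.
inline int stochastic_round(float x, uint32_t rng)
{
    const float lo   = std::floor(x);
    const float frac = x - lo;                       // in [0, 1)
    const float u    = rng * (1.0f / 4294967296.0f); // uniform in [0, 1)
    return static_cast<int>(lo) + (u < frac ? 1 : 0);
}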
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x);
#else
    return f4_convert_rne(x);
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t type_convert<f4x2_t, float2_t>(float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x);
#else
    return f4_convert_rne(x);
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t type_convert<f4x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x);
#else
    return f4_convert_rne(x);
#endif
}
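A hypothetical caller-side sketch of the specializations above (assumes the ck headers from this file are on the include path; CK_USE_SR_F4_CONVERSION selects the rounding mode at compile time):

#include "ck/utility/type_convert.hpp"

// Quantize one float to fp4 and convert it back; with the default scale of 1
// the round trip lands on the nearest representable fp4 value.
inline float fp4_round_trip(float x)
{
    ck::f4_t q = ck::type_convert<ck::f4_t>(x);
    return ck::type_convert<float>(q);
}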
// convert fp4 to fp32
template <>
inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
{
#if defined(__gfx950__)
    union
    {
        float float_array[2];
        float2_t float2_array;
    } float_values{};
    float scale = 1.0f;
    float_values.float2_array = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0);
    return float_values.float_array[0];
#else
    return utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{};
    value.f4x2_array[0] = x;
    float scale = 1.0f;
    return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
#else
    float2_t ret{
        utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
                              x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
        utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
                              x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
    return ret;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
{
#if defined(__gfx950__)
    union
    {
        f4x32_t f4x32_array;
        f4x2_t fp4x2[16];
    } value{x};
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } bitwise_value{};
    float2_t op;
    float32_t ret;
    float scale = 1.0f;
    // TODO: pack in a loop
    bitwise_value.f4x2_array[0] = value.fp4x2[0];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[0] = op[0]; ret[1] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[1];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[2] = op[0]; ret[3] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[2];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[4] = op[0]; ret[5] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[3];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[6] = op[0]; ret[7] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[4];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[8] = op[0]; ret[9] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[5];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[10] = op[0]; ret[11] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[6];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[12] = op[0]; ret[13] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[7];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[14] = op[0]; ret[15] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[8];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[16] = op[0]; ret[17] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[9];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[18] = op[0]; ret[19] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[10];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[20] = op[0]; ret[21] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[11];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[22] = op[0]; ret[23] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[12];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[24] = op[0]; ret[25] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[13];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[26] = op[0]; ret[27] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[14];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[28] = op[0]; ret[29] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[15];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[30] = op[0]; ret[31] = op[1];
    return ret;
#else
    union
    {
        float32_t float32_array;
        float float_array[32];
    } float_values{};
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{bit_cast<__uint128_t>(x)};
    // TODO: pack in a loop (indices below run over all 32 output floats)
    float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[8] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[9] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[10] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[11] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[12] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[13] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[14] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[15] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[16] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[17] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[18] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[19] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[20] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[21] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[22] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[23] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[24] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[25] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[26] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[27] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[28] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[29] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[30] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[31] = utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    return float_values.float32_array;
#endif
}
/**
* @brief Converts a float to a 6-bit float type (f6_t) using round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts it
* to the 6-bit floating-point format (f6_t).
*
* @param x The input float value.
* @param scale A scaling factor applied to `x` before conversion.
* @return The converted f6_t value.
*/
inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t in1{x};
    float16_t in2{};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale);
    return out.f6_array[0];
#else
    return utils::sat_convert_to_type<f6_t>(x / scale);
#endif
}
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
* This function divides each input float by the provided scale value, then performs conversion with
* rounding to nearest / even to pack each element into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    float16_t* in2 = reinterpret_cast<float16_t*>(reinterpret_cast<float*>(&x) + 16); // second 16 floats
    return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.f6_array[i] = utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
    });
    return out.f6_vector;
#endif
}
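The fallback branch above uses ck::static_for to unroll the 32 per-element conversions at compile time. A stand-alone sketch of that idiom with a C++17 fold expression follows; the names static_for_each/static_for_impl are illustrative, not the CK utility itself:

#include <cstddef>
#include <type_traits>
#include <utility>

// Call f(integral_constant<size_t, 0>{}), ..., f(integral_constant<size_t, N-1>{})
// with compile-time indices, the same shape as the ck::static_for loop above.
template <typename F, std::size_t... Is>
constexpr void static_for_impl(F&& f, std::index_sequence<Is...>)
{
    (f(std::integral_constant<std::size_t, Is>{}), ...);
}

template <std::size_t N, typename F>
constexpr void static_for_each(F&& f)
{
    static_for_impl(std::forward<F>(f), std::make_index_sequence<N>{});
}

// usage sketch: static_for_each<32>([&](auto i) { out[i] = convert(in[i]); });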
/**
* @brief Converts a float to the 6-bit floating-point type (f6_t) using stochastic rounding.
*
* Divides the input by the specified scale, then performs saturation and conversion
* to f6_t based on a pseudo-randomly generated seed.
*
* @param x The input float value.
* @param scale A scaling factor applied to `x` before conversion.
* @return The converted f6_t value.
*/
inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    out.f6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale);
    return out.f6_array[0];
#else
    return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
#endif
}
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
* This function divides each input float by the provided scale value, then performs conversion with
* stochastic rounding to pack each element into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    union
    {
        float32_t float_vector;
        float float_array[32];
    } float_values{x};
    uint32_t rng =
        prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.f6_array[i] = utils::sat_convert_to_type_sr<f6_t>(in.float_array[i] / scale, rng);
    });
    return out.f6_vector;
#endif
}
/**
* @brief Specializes the type conversion template for converting a float into the 6-bit float type
* (f6_t).
*
* Depending on the CK_USE_SR_F6_CONVERSION flag,
* the conversion uses stochastic rounding
* or round-to-nearest-even.
*
* @param x Input float value to be converted.
* @return The converted f6_t value.
*/
template <>
inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x);
#else
    return f6_convert_rne(x);
#endif
}
/**
* @brief Specializes the type conversion template for converting a vector of 32 floats into the
* vector of 32 6-bit float types (f6x32_t).
*
* Depending on the CK_USE_SR_F6_CONVERSION flag,
* the conversion uses stochastic rounding
* or round-to-nearest-even.
*
* @param x Input float value to be converted.
* @return The converted f6x32_t vector.
*/
template <>
inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x);
#else
    return f6_convert_rne(x);
#endif
}
/**
* @brief Specializes the type conversion template for converting the 6-bit float type (f6_t) to
* float.
*
* Interprets an f6_t value as a float using the default scale factor of 1.
*
* @param x The 6-bit float (f6_t) value to be converted.
* @return The corresponding float representation.
*/
template <>
inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
{
#if defined(__gfx950__)
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
        in.f6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
    return out.float_array[0];
#else
    return utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
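Both branches above pass NumericLimits<e8m0_bexp_t>::Binary_1() so the value is decoded with a scale of exactly one. Assuming the usual OCP-MX e8m0 encoding (an 8-bit biased exponent with no mantissa, value 2^(e-127)), decoding such a scale byte looks like the following sketch; it is an illustration under that assumption, not the CK helper:

#include <cmath>
#include <cstdint>

// Decode an e8m0 scale byte to a float: a bare biased exponent, no mantissa.
inline float decode_e8m0(uint8_t e)
{
    return std::ldexp(1.0f, static_cast<int>(e) - 127); // e == 127 -> scale of 1.0
}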
/**
* @brief Specializes the type conversion template for converting the vector of 32 6-bit float types
* (f6x32_t) to vector of 32 floats.
*
* Interprets an f6_t values as floats using the default scale factor of 1.
*
* @param x The vector of 32 6-bit float (f6x32_t) values to be converted.
* @return The corresponding float representation.
*/
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
        x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.float_array[i] = utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.f6_array[i]);
    });
    return out.float_vector;
#endif
}
/**
* @brief Converts a float to the 6-bit BF6 type using round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts
* it to a 6-bit BF6 floating-point format.
*
* @param x The float value to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6_t value.
*/
inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t in1{x};
    float16_t in2{};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    out.bf6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale);
    return out.bf6_array[0];
#else
    return utils::sat_convert_to_type<bf6_t>(x / scale);
#endif
}
/**
* @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using
* round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts
* it to a 6-bit BF6 floating-point format.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    float16_t* in2 = reinterpret_cast<float16_t*>(reinterpret_cast<float*>(&x) + 16); // second 16 floats
    return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
    });
    return out.bf6_vector;
#endif
}
/**
* @brief Converts a float to the 6-bit BF6 type using stochastic rounding.
*
* Divides the input by the specified scale,
* and converts the result to a 6-bit BF6 floating-point
* format with stochastic rounding.
*
* @param x The float value to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6_t value.
*/
inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    out.bf6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale);
    return out.bf6_array[0];
#else
    return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
#endif
}
/**
* @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using stochastic
* rounding.
*
* Divides the input by the specified scale,
* and converts the result to a 6-bit BF6 floating-point
* format with stochastic rounding.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    union
    {
        float32_t float_vector;
        float float_array[32];
    } float_values{x};
    uint32_t rng =
        prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] = utils::sat_convert_to_type_sr<bf6_t>(in.float_array[i] / scale, rng);
    });
    return out.bf6_vector;
#endif
}
/**
* @brief Specializes float-to-bf6_t conversion.
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float value to convert.
* @return Converted bf6_t value.
*/
template <>
inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x);
#else
    return bf6_convert_rne(x);
#endif
}
/**
* @brief Specializes vector of 32 float-to-bf6_t conversion.
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float vector to convert.
* @return Converted bf6x32_t vector.
*/
template <>
inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x);
#else
    return bf6_convert_rne(x);
#endif
}
/**
* @brief Specializes the type conversion template for converting a bf6_t value to float.
*
* Interprets the bf6_t value using the default scale factor of 1 and returns
* its floating-point representation.
*
* @param x The bf6_t value to convert.
* @return The float representation of the given bf6_t value.
*/
template <>
inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
{
#if defined(__gfx950__)
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
        in.bf6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
    return out.float_array[0];
#else
    return utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
/**
* @brief Specializes the type conversion template for converting a vector of 32 bf6_t values to
* vector of 32 floats.
*
* Interprets the bf6x32_t value using the default scale factor of 1 and returns
* its floating-point representation.
*
* @param x The bf6x32_t value to convert.
* @return The float representation of the given vector.
*/
template <>
inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
        x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.float_array[i] = utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.bf6_array[i]);
    });
    return out.float_vector;
#endif
}
#ifndef CK_CODE_GEN_RTC
template <typename Y, typename X, size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
...
...
include/ck_tile/core.hpp
...
...
@@ -27,6 +27,7 @@
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
...
...
include/ck_tile/core/config.hpp
...
...
@@ -144,6 +144,10 @@
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
#endif
#ifndef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
#define CK_TILE_USE_PK4_LAYOUT_SHUFFLE 1
#endif
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
...
...
include/ck_tile/core/numeric/pk_int4.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>
#include "ck_tile/core/numeric/int8.hpp"
#pragma once
namespace ck_tile {

// Packed 2xint4
struct pk_int4_t
{
    using type = int8_t;
    type data;
    __host__ __device__ constexpr pk_int4_t() : data{type{}} {}
    __host__ __device__ constexpr pk_int4_t(type init) : data{init} {}
};
// limits
template <class T>
struct numeric;

template <>
struct numeric<pk_int4_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t min()
    {
        constexpr uint8_t val = 0b10001000;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t lowest()
    {
        constexpr uint8_t val = 0b10001000;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t max()
    {
        constexpr uint8_t val = 0b01110111;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t epsilon() { return 1; } // not used

    CK_TILE_HOST_DEVICE static constexpr pk_int4_t round_error() { return 1; } // not used

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t infinity() { return 1; } // not used

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t quiet_NaN() { return 1; } // not used

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t signaling_NaN() { return 1; } // not used

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t denorm_min() { return 1; } // not used

    CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
};
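The min()/lowest() and max() constants above pack the same signed 4-bit value into both nibbles: 0b1000 is -8 and 0b0111 is +7 in two's-complement int4. A small stand-alone check of that packing (illustration only; pack_int4_pair is a hypothetical helper, not part of ck_tile):

#include <cassert>
#include <cstdint>

// Pack two signed 4-bit values (each in [-8, 7]) into one byte, low value in the low nibble.
inline uint8_t pack_int4_pair(int lo, int hi)
{
    return static_cast<uint8_t>(((hi & 0xF) << 4) | (lo & 0xF));
}

int main()
{
    assert(pack_int4_pair(-8, -8) == 0b10001000); // matches numeric<pk_int4_t>::min()/lowest()
    assert(pack_int4_pair(7, 7) == 0b01110111);   // matches numeric<pk_int4_t>::max()
    return 0;
}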
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);

    float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    fp32x2_t res = {x_h, x_l};
#else
    fp32x2_t res = {x_l, x_h};
#endif
    return res;
}

CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);

#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#else
    uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif

    const int EX  = 0x64006400;
    const int SUB = 0xE408E408; //-8

    int lo = i4s | EX;
    return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
}

CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);

    float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)};
#else
    bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)};
#endif
    return res;
}

} // namespace ck_tile
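With the shuffle layout enabled (config.hpp above sets CK_TILE_USE_PK4_LAYOUT_SHUFFLE to 1), the decoded pair above comes back as (high nibble, low nibble). A stand-alone restatement of that fallback arithmetic, for illustration only:

#include <cassert>
#include <cstdint>

int main()
{
    const uint8_t x_u8 = 0x2B;                              // high = 0x2, low = 0xB
    const float lo = static_cast<float>(x_u8 & 0x0F) - 8.f; // 11 - 8 = 3
    const float hi = static_cast<float>(x_u8 >> 4) - 8.f;   // 2 - 8 = -6
    const float res[2] = {hi, lo};                          // shuffled order
    assert(res[0] == -6.f && res[1] == 3.f);
    return 0;
}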
include/ck_tile/core/numeric/vector_type.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -200,4 +200,21 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using
bf8x64_t
=
bf8_t
__attribute
((
ext_vector_type
(
64
)));
#endif
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
    fp16x2_t vector_res;
    vector_res.x = x.x + y.x;
    vector_res.y = x.y + y.y;
    return vector_res;
}

CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
    fp16x2_t c;
    asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
    return c;
}
} // namespace ck_tile
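pk_add_f16 adds the two fp16 lanes of its operands element-wise: the host overload does it in plain C++, while the device overload maps straight onto the GFX9 packed-math instruction v_pk_add_f16. A minimal device-side usage sketch (illustrative only; the kernel, bounds check and header choice are assumptions, not part of this change):

// Hypothetical kernel showing how the packed add is typically called per thread.
#include <hip/hip_runtime.h>
#include "ck_tile/core.hpp"

__global__ void pk_add_demo(const ck_tile::fp16x2_t* a,
                            const ck_tile::fp16x2_t* b,
                            ck_tile::fp16x2_t* out,
                            int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
    {
        out[i] = ck_tile::pk_add_f16(a[i], b[i]); // one v_pk_add_f16 per thread on device
    }
}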
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc
...
...
@@ -824,4 +824,4 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#undef CK_TILE_FLATMM_UK_MFMA
// clang-format on
// clang-format on
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc
...
...
@@ -722,4 +722,4 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#undef CK_TILE_FLATMM_UK_MFMA
// clang-format on
// clang-format on
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc
...
...
@@ -771,4 +771,4 @@
#undef _UK_MFMA_
#undef CK_TILE_FLATMM_UK_2B
#undef CK_TILE_FLATMM_UK_MFMA
// clang-format on
// clang-format on
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
...
...
@@ -41,13 +41,16 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
// clang-format on
#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>
;
template
<
ck
::
index_t
NDimSpatial
,
...
...
@@ -58,11 +61,13 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
>
,
...
...
@@ -72,6 +77,7 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
32
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
128
,
32
,
32
,
8
,
32
,
32
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
>
#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>
;
...
...
@@ -106,13 +112,16 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
// clang-format on
#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>
;
template
<
ck
::
index_t
NDimSpatial
,
...
...
@@ -123,11 +132,13 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
>
,
...
...
@@ -137,6 +148,7 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instance
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
32
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
128
,
32
,
32
,
8
,
32
,
32
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
>
#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>
;
...
...
@@ -171,13 +183,16 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
,
F16
,
F16
,
1
,
1
>
// clang-format on
#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>
;
// NGCHW requires transpose, we use vector loads and stores params for them
...
...
@@ -189,11 +204,13 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
,
F16
,
F16
,
1
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
,
F16
,
F16
,
2
,
2
>
,
...
...
@@ -217,6 +234,7 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
32
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
,
F16
,
F16
,
4
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
128
,
32
,
32
,
8
,
32
,
32
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
4
>
,
1
,
Scheduler
,
PipelineVersion
,
8
,
F16
,
F16
,
8
,
1
>
#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>
;
...
...
@@ -229,13 +247,16 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
,
BF16
,
BF16
,
1
,
1
>
// clang-format on
#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>
;
template
<
ck
::
index_t
NDimSpatial
,
...
...
@@ -246,11 +267,13 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
,
BF16
,
BF16
,
1
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
,
BF16
,
BF16
,
2
,
2
>
,
...
...
@@ -274,6 +297,7 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instance
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
32
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
,
BF16
,
BF16
,
4
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
128
,
32
,
32
,
8
,
32
,
32
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
4
>
,
1
,
Scheduler
,
PipelineVersion
,
8
,
BF16
,
BF16
,
8
,
1
>
#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp
...
...
@@ -56,11 +56,13 @@ template <index_t NDimSpatial,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_bf16_comp_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
// Compute friendly
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
...
...
@@ -79,7 +81,7 @@ using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
64
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
64
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
64
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
#endif // defined(__gfx950__)
// clang-format on
>
;
...
...
@@ -90,11 +92,13 @@ template <index_t NDimSpatial,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_f16_comp_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
...
...
@@ -109,6 +113,7 @@ using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
#endif // defined(__gfx950__)
// clang-format on
>
;
...
...
@@ -138,11 +143,13 @@ template <index_t NDimSpatial,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_int8_comp_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
...
...
@@ -153,6 +160,7 @@ using device_grouped_conv_fwd_xdl_int8_comp_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
#endif // defined(__gfx950__)
// clang-format on
>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
...
...
@@ -40,15 +40,18 @@ template <index_t NDimSpatial,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
BF16
,
BF16
,
LoopScheduler
::
Default
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
BF16
,
BF16
,
LoopScheduler
::
Default
,
16
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
BF16
,
BF16
,
LoopScheduler
::
Default
,
32
>
#endif // defined(__gfx950__)
// clang-format on
>
;
...
...
@@ -59,15 +62,18 @@ template <index_t NDimSpatial,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_merged_groups_f16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F16
,
F16
,
LoopScheduler
::
Default
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F16
,
F16
,
LoopScheduler
::
Default
,
16
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F16
,
F16
,
LoopScheduler
::
Default
,
32
>
#endif // defined(__gfx950__)
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
...
...
@@ -69,7 +69,7 @@ function(add_instance_library INSTANCE_NAME)
endforeach()
# Do not build mha instances if gfx94 or gfx90a targets are not on the target list
foreach(source IN LISTS ARGN)
    if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha")
    if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND
       NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "mha")
        message("removing mha instance ${source}")
        list(REMOVE_ITEM ARGN "${source}")
    endif()
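With the extra clause, gfx950 joins gfx90a and gfx94x as a target for which the mha instance sources are kept instead of pruned from ARGN. A small, hypothetical sketch of the resulting behaviour (the variable and message below are illustrative, not part of the build system):

# Sketch only: gfx950 in the target list now counts as MFMA-capable for mha.
set(example_targets "gfx950;gfx1100")
if(example_targets MATCHES "gfx94" OR example_targets MATCHES "gfx90a"
   OR example_targets MATCHES "gfx95")
    message(STATUS "mha instances would be built for: ${example_targets}")
endif()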
...
...
@@ -77,25 +77,25 @@ function(add_instance_library INSTANCE_NAME)
# Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
    foreach(source IN LISTS ARGN)
        if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "gemm_multiply_multiply_xdl_f8")
        if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND
           source MATCHES "gemm_multiply_multiply_xdl_f8")
            message("removing gemm_multiply_multiply_f8 instance ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    foreach(source IN LISTS ARGN)
        if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "gemm_xdl_universal" AND source MATCHES "_f8_")
        if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND
           source MATCHES "gemm_xdl_universal" AND source MATCHES "_f8_")
            message("removing gemm_universal_f8 instance ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    foreach(source IN LISTS ARGN)
        if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "batched_gemm_xdl_universal" AND source MATCHES "_f8_")
        if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND
           source MATCHES "batched_gemm_xdl_universal" AND source MATCHES "_f8_")
            message("removing batched_gemm_universal_f8 instance ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    foreach(source IN LISTS ARGN)
        if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "gemm_xdl_universal_streamk" AND source MATCHES "_f8_")
        if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND
           source MATCHES "gemm_xdl_universal_streamk" AND source MATCHES "_f8_")
            message("removing gemm_universal_streamk_f8 instance ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
...
...
@@ -109,7 +109,7 @@ function(add_instance_library INSTANCE_NAME)
if(source MATCHES "_xdl")
    list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
elseif(source MATCHES "_wmma")
    list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
    list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
         gfx950)
elseif(source MATCHES "mha")
    list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
endif()
...
...
@@ -368,7 +368,7 @@ if(CK_DEVICE_CONV_INSTANCES)
endif()
if(CK_DEVICE_MHA_INSTANCES)
    set(gpu_list ${INST_TARGETS})
    if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a")
    if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a" OR gpu_list MATCHES "gfx95")
        add_library(device_mha_operations ${CK_DEVICE_MHA_INSTANCES})
        set_target_properties(device_mha_operations
            PROPERTIES
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
...
...
@@ -27,12 +27,15 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
#if defined(CK_USE_AMD_MFMA_GFX950)
DeviceBatchedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
#else
DeviceBatchedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
,
DeviceBatchedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
,
DeviceBatchedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
,
...
...
@@ -65,6 +68,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple<
DeviceBatchedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v2
>
,
DeviceBatchedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v2
>
#endif
#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>
;
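Tuples like the one above are only declarations of concrete kernel configurations; the factory picks them up through a registration function that normally closes this kind of instance file. The sketch below shows that common pattern; the interface template arguments, namespaces and type aliases are assumed from the surrounding file rather than copied from this diff.

// Hedged sketch of the usual registration tail for an instance file like this one.
// The DeviceBatchedGemm interface arguments and the namespaces are assumptions
// based on the library's common layout, not taken from this commit.
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances)
{
    // Appends one entry per element of the tuple declared above.
    add_device_operation_instances(
        instances, device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances{});
}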
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
...
...
@@ -27,12 +27,15 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
#if defined(CK_USE_AMD_MFMA_GFX950)
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#else
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
...
...
@@ -65,6 +68,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple<
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
#endif // defined(CK_USE_AMD_MFMA_GFX950)
        // clang-format on
        >;
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
...
...
@@ -26,23 +26,30 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_generic_instances = std::tuple<
        // clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#else
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#endif
        // clang-format on
        >;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple<
        // clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
#if defined(CK_USE_AMD_MFMA_GFX950)
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#else
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
...
...
@@ -102,6 +109,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple<
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
#endif // defined(CK_USE_AMD_MFMA_GFX950)
        // clang-format on
        >;
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
...
...
@@ -26,23 +26,30 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_generic_instances = std::tuple<
        // clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#else
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#endif
        // clang-format on
        >;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple<
        // clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
#if defined(CK_USE_AMD_MFMA_GFX950)
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#else
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
...
...
@@ -90,6 +97,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple<
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
#endif // defined(CK_USE_AMD_MFMA_GFX950)
        // clang-format on
        >;
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...
...
@@ -26,18 +26,22 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale       = ck::tensor_operation::element_wise::Scale;
#if !defined(CK_USE_AMD_MFMA_GFX950)
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded  = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#endif
// c[g, m, n] = a[g, m, k] * b[g, n, k]
template <bool Masking>
using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
        // clang-format off
//#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskOut|
//#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
...
...
@@ -53,24 +57,28 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_
// Padded fallback kernel
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>
        // clang-format on
#endif // defined(CK_USE_AMD_MFMA_GFX950)
        >;
template <bool Masking>
using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances = std::tuple<
        // clang-format off
//#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskOut|
//#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
        DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>
        // clang-format on
#endif // defined(CK_USE_AMD_MFMA_GFX950)
        >;
void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...
...
@@ -26,10 +26,12 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd    = ck::tensor_operation::element_wise::ScaleAdd;
#if !defined(CK_USE_AMD_MFMA_GFX950)
static constexpr auto GemmDefault   = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded    = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
#endif
// c[g, m, n] = a[g, m, k] * b[g, n, k]
template <index_t NumDimG,
...
...
@@ -40,11 +42,13 @@ template <index_t NumDimG,
          MaskingSpecialization MaskingSpec>
using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances = std::tuple<
        // clang-format off
// #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| D0s Bias|
// #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | SrcScalar|
// #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | PerVector|
// #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
...
...
@@ -62,7 +66,8 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>
        // clang-format on
#endif // defined(CK_USE_AMD_MFMA_GFX950)
        >;
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...
...
@@ -26,10 +26,12 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd    = ck::tensor_operation::element_wise::ScaleAdd;
#if !defined(CK_USE_AMD_MFMA_GFX950)
static constexpr auto GemmDefault   = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded    = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
#endif
// c[g, m, n] = a[g, m, k] * b[g, n, k]
template <index_t NumDimG,
...
...
@@ -40,11 +42,13 @@ template <index_t NumDimG,
          MaskingSpecialization MaskingSpec>
using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
        // clang-format off
// #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| D0s Bias|
// #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | SrcScalar|
// #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | PerVector|
// #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
...
...
@@ -64,7 +68,8 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>
        // clang-format on
#endif // defined(CK_USE_AMD_MFMA_GFX950)
        >;
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...
...
@@ -26,10 +26,12 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale       = ck::tensor_operation::element_wise::Scale;
#if !defined(CK_USE_AMD_MFMA_GFX950)
static constexpr auto GemmDefault   = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded    = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
#endif
// c[g, m, n] = a[g, m, k] * b[g, n, k]
template <index_t NumDimG,
...
...
@@ -40,11 +42,13 @@ template <index_t NumDimG,
          MaskingSpecialization MaskingSpec>
using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances = std::tuple<
        // clang-format off
// #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec|
// #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
// #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
// #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
...
...
@@ -60,7 +64,8 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
// Padded fallback kernel
        DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>
        // clang-format on
#endif // defined(CK_USE_AMD_MFMA_GFX950)
        >;
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
...
...