gaoqiong / composable_kernel_ROCM

Commit 9ba504b6, authored Feb 07, 2025 by ThomasNing
merge with the develop support the fp8 with computev4
Parents: e3402c93, f49de496
Changes: 198. Showing 20 changed files with 815 additions and 154 deletions (+815, -154).
include/ck_tile/core/numeric/pk_int4.hpp (+140, -0)
include/ck_tile/core/numeric/vector_type.hpp (+18, -1)
include/ck_tile/core/utility/transpose_vectors.hpp (+73, -43)
include/ck_tile/host/check_err.hpp (+13, -7)
include/ck_tile/host/reference/reference_gemm.hpp (+3, -2)
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp (+2, -1)
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc (+1, -1)
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc (+1, -1)
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc (+1, -1)
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp (+4, -1)
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp (+7, -5)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp (+70, -9)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp (+236, -55)
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp (+0, -1)
library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp (+178, -0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp (+36, -12)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp (+12, -4)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp (+8, -2)
library/src/tensor_operation_instance/gpu/CMakeLists.txt (+7, -7)
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp (+5, -1)
include/ck_tile/core/numeric/pk_int4.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>
#include "ck_tile/core/numeric/int8.hpp"

#pragma once

namespace ck_tile {

// Packed 2x int4
struct pk_int4_t
{
    using type = int8_t;
    type data;
    __host__ __device__ constexpr pk_int4_t() : data{type{}} {}
    __host__ __device__ constexpr pk_int4_t(type init) : data{init} {}
};

// limits
template <class T>
struct numeric;

template <>
struct numeric<pk_int4_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t min()
    {
        constexpr uint8_t val = 0b10001000;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t lowest()
    {
        constexpr uint8_t val = 0b10001000;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t max()
    {
        constexpr uint8_t val = 0b01110111;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t epsilon()
    {
        return 1; // not used
    }

    CK_TILE_HOST_DEVICE static constexpr pk_int4_t round_error()
    {
        return 1; // not used
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t infinity()
    {
        return 1; // not used
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t quiet_NaN()
    {
        return 1; // not used
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t signaling_NaN()
    {
        return 1; // not used
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t denorm_min()
    {
        return 1; // not used
    }

    CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero()
    {
        return 0;
    }
};

CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
    float x_l    = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h    = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    fp32x2_t res = {x_h, x_l};
#else
    fp32x2_t res = {x_l, x_h};
#endif
    return res;
}

CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#else
    uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif
    const int EX  = 0x64006400;
    const int SUB = 0xE408E408; // -8
    int lo = i4s | EX;
    return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
}

CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
    float x_l    = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h    = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)};
#else
    bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)};
#endif
    return res;
}

} // namespace ck_tile
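A note on the fp16 path above: ORing a 4-bit nibble v into the low mantissa bits of the fp16 pattern 0x6400 (which encodes 1024.0) produces 1024 + v, and pk_add_f16 with the constant 0xE408 (which encodes -1032.0) then yields v - 8, i.e. the signed int4 value. The host-only sketch below is illustrative and not part of the commit; it simply replays that arithmetic for all sixteen nibble values.

#include <cstdint>
#include <cstdio>

int main()
{
    for(uint32_t v = 0; v < 16; ++v)
    {
        // fp16 bit pattern 0x6400 | v decodes to 1024.0 + v (v sits in the mantissa LSBs)
        const double or_with_exponent = 1024.0 + static_cast<double>(v);
        // adding fp16 0xE408 (= -1032.0) recovers the signed int4 value v - 8
        const double signed_int4 = or_with_exponent - 1032.0;
        std::printf("nibble 0x%X -> %+3.0f\n", v, signed_int4);
    }
    return 0;
}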
include/ck_tile/core/numeric/vector_type.hpp

-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
@@ -200,4 +200,21 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
 using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
 #endif

+CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
+{
+    fp16x2_t vector_res;
+    vector_res.x = x.x + y.x;
+    vector_res.y = x.y + y.y;
+    return vector_res;
+}
+
+CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
+{
+    fp16x2_t c;
+    asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
+    return c;
+}
 } // namespace ck_tile
include/ck_tile/core/utility/transpose_vectors.hpp

...
@@ -68,52 +68,82 @@ struct transpose_vectors
         }
         else if constexpr(sizeof(S) == 1)
         {
-            static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");
+            static_assert(((NX % 4 == 0 && NY % 4 == 0) || (NX % 2 == 0 && NY % 2 == 0)), "wrong!");

             using S4 = array<S, 4>; // typename array<S, 4>::type;
-            // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
-            static_for<0, NY, 4>{}([&](auto iy) {
-                static_for<0, NX, 4>{}([&](auto ix) {
-                    // 4 int8x4 data from vx_tuple
-                    const int32_t x_s4_0 = bit_cast<int32_t>(vx_tuple[ix].template get_as<S4>()[iy / I4]);
-                    const int32_t x_s4_1 = bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S4>()[iy / I4]);
-                    const int32_t x_s4_2 = bit_cast<int32_t>(vx_tuple[ix + I2].template get_as<S4>()[iy / I4]);
-                    const int32_t x_s4_3 = bit_cast<int32_t>(vx_tuple[ix + I3].template get_as<S4>()[iy / I4]);
-                    // transpose
-                    int32_t t_s4_0, t_s4_1;
-                    int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3;
-                    constexpr int32_t m0 = 0x05010400;
-                    constexpr int32_t m1 = 0x05040100;
-                    constexpr int32_t m2 = 0x07060302;
-                    constexpr int32_t m3 = 0x07030602;
-                    // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
-                    // index              7  6  5  4      3  2  1  0     33 77 44 88
-                    // index is reversed because of little endianness (least significant bits first)
-                    t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0);
-                    t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0);
-                    y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
-                    y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
-                    t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3);
-                    t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3);
-                    y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
-                    y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
-                    // 4 int8x4 data from vy_tuple
-                    vy_tuple(iy).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_0);
-                    vy_tuple(iy + I1).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_1);
-                    vy_tuple(iy + I2).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_2);
-                    vy_tuple(iy + I3).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_3);
-                });
-            });
+            using S2 = array<S, 2>; // typename array<S, 4>::type;
+            if constexpr(NX % 4 == 0 && NY % 4 == 0)
+            {
+                // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
+                static_for<0, NY, 4>{}([&](auto iy) {
+                    static_for<0, NX, 4>{}([&](auto ix) {
+                        // 4 int8x4 data from vx_tuple
+                        const int32_t x_s4_0 = bit_cast<int32_t>(vx_tuple[ix].template get_as<S4>()[iy / I4]);
+                        const int32_t x_s4_1 = bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S4>()[iy / I4]);
+                        const int32_t x_s4_2 = bit_cast<int32_t>(vx_tuple[ix + I2].template get_as<S4>()[iy / I4]);
+                        const int32_t x_s4_3 = bit_cast<int32_t>(vx_tuple[ix + I3].template get_as<S4>()[iy / I4]);
+                        // transpose
+                        int32_t t_s4_0, t_s4_1;
+                        int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3;
+                        constexpr int32_t m0 = 0x05010400;
+                        constexpr int32_t m1 = 0x05040100;
+                        constexpr int32_t m2 = 0x07060302;
+                        constexpr int32_t m3 = 0x07030602;
+                        // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) ->
+                        // 0x33774488
+                        // index              7  6  5  4      3  2  1  0     33 77 44 88
+                        // index is reversed because of little endianness (least significant bits
+                        // first)
+                        t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0);
+                        t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0);
+                        y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
+                        y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
+                        t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3);
+                        t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3);
+                        y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
+                        y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
+                        // 4 int8x4 data from vy_tuple
+                        vy_tuple(iy).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_0);
+                        vy_tuple(iy + I1).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_1);
+                        vy_tuple(iy + I2).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_2);
+                        vy_tuple(iy + I3).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_3);
+                    });
+                });
+            }
+            else if constexpr(NX % 2 == 0 && NY % 2 == 0)
+            {
+                static_for<0, NY, 2>{}([&](auto ix) {
+                    static_for<0, NX, 2>{}([&](auto iy) {
+                        const int16_t x_s2_0 = bit_cast<int16_t>(vx_tuple[ix].template get_as<S2>()[iy / I2]);
+                        const int16_t x_s2_1 = bit_cast<int16_t>(vx_tuple[ix + I1].template get_as<S2>()[iy / I2]);
+                        constexpr int32_t m0 = 0x05040100;
+                        constexpr int32_t m1 = 0x07060302;
+                        const int32_t x0_32 = static_cast<int32_t>(x_s2_0 & 0xFFFF);
+                        const int32_t x1_32 = static_cast<int32_t>(x_s2_1 & 0xFFFF);
+                        const int32_t y_s2_0 = __builtin_amdgcn_perm(x1_32, x0_32, m0);
+                        const int32_t y_s2_1 = __builtin_amdgcn_perm(x1_32, x0_32, m1);
+                        vy_tuple(iy).template get_as<S2>()[ix / I2] =
+                            bit_cast<S2>(static_cast<int16_t>(y_s2_0 & 0xFFFF));
+                        vy_tuple(iy + I1).template get_as<S2>()[ix / I2] =
+                            bit_cast<S2>(static_cast<int16_t>(y_s2_1 & 0xFFFF));
+                    });
+                });
+            }
         }
         else
         {
...
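The comment inside the diff above works through one v_perm_b32 example. The host-side sketch below (illustrative only, not part of the commit) emulates the byte-selection semantics used here for selector byte values 0..7: source byte 0 is the least significant byte of the second operand, source byte 7 the most significant byte of the first operand; selector values above 7 have special meanings and are out of scope for this sketch.

#include <cstdint>
#include <cstdio>

// Host emulation of __builtin_amdgcn_perm(hi, lo, sel) restricted to
// selector byte values 0..7 (byte picks from the 64-bit value {hi, lo}).
static uint32_t perm_b32(uint32_t hi, uint32_t lo, uint32_t sel)
{
    const uint64_t src = (static_cast<uint64_t>(hi) << 32) | lo;
    uint32_t result    = 0;
    for(int i = 0; i < 4; ++i)
    {
        const uint32_t idx  = (sel >> (8 * i)) & 0x7;     // selector byte i (masked to 0..7)
        const uint32_t byte = static_cast<uint32_t>((src >> (8 * idx)) & 0xFF);
        result |= byte << (8 * i);
    }
    return result;
}

int main()
{
    // Reproduces the example from the comment in transpose_vectors.hpp:
    // v_perm_b32(0x11223344, 0x55667788, 0x05010400) -> 0x33774488
    std::printf("0x%08X\n", perm_b32(0x11223344u, 0x55667788u, 0x05010400u));
    return 0;
}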
include/ck_tile/host/check_err.hpp

...
@@ -22,13 +22,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
 double get_relative_threshold(const int number_of_accumulations = 1)
 {
     using F8   = ck_tile::fp8_t;
+    using BF8  = ck_tile::bf8_t;
     using F16  = ck_tile::half_t;
     using BF16 = ck_tile::bf16_t;
     using F32  = float;
     using I8   = int8_t;
     using I32  = int32_t;

-    static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled ComputeDataType for setting up the relative threshold!");

     double compute_error = 0;
...
@@ -41,7 +42,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
         compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
     }

-    static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled OutDataType for setting up the relative threshold!");

     double output_error = 0;
...
@@ -55,7 +56,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
     }
     double midway_error = std::max(compute_error, output_error);

-    static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled AccDataType for setting up the relative threshold!");

     double acc_error = 0;
...
@@ -74,13 +75,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
 double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
 {
     using F8   = ck_tile::fp8_t;
+    using BF8  = ck_tile::bf8_t;
     using F16  = ck_tile::half_t;
     using BF16 = ck_tile::bf16_t;
     using F32  = float;
     using I8   = int8_t;
     using I32  = int32_t;

-    static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");

     auto expo = std::log2(std::abs(max_possible_num));
...
@@ -94,7 +96,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
         compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
     }

-    static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled OutDataType for setting up the absolute threshold!");

     double output_error = 0;
...
@@ -108,7 +110,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
     }
     double midway_error = std::max(compute_error, output_error);

-    static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled AccDataType for setting up the absolute threshold!");

     double acc_error = 0;
...
@@ -501,7 +503,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
     }
     if(!res)
     {
-        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        const float error_percent =
+            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
+        std::cerr << "max err: " << max_err;
+        std::cerr << ", number of errors: " << err_count;
+        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
     }
     return res;
 }
...
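For intuition about the thresholds above: get_relative_threshold sets the per-operation relative error to 2^(-mant) * 0.5, so with half_t (10 mantissa bits) it is roughly 4.9e-4. The sketch below is illustrative only; the mantissa widths are assumptions here rather than values taken from ck_tile::numeric_traits.

#include <cmath>
#include <cstdio>

int main()
{
    // Assumed mantissa widths: fp16 = 10 bits, bf16 = 7 bits, fp8 (e4m3) = 3 bits.
    const int mant_bits[] = {10, 7, 3};
    const char* names[]   = {"fp16", "bf16", "fp8"};
    for(int i = 0; i < 3; ++i)
    {
        const double rel = std::pow(2.0, -mant_bits[i]) * 0.5; // same formula as get_relative_threshold
        std::printf("%s: relative threshold ~ %g\n", names[i], rel);
    }
    return 0;
}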
include/ck_tile/host/reference/reference_gemm.hpp

...
@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
         int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                           ? col * strideB + k
                           : k * strideB + col;
-        acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
+        acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
+               ck_tile::type_convert<AccDataType>(B[b_index]);
     }

     int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>) ? row * strideC + col
                                                                            : col * strideC + row;
-    C[c_index] = acc;
+    C[c_index] = ck_tile::type_convert<CDataType>(acc);
 }
...
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp

...
@@ -77,6 +77,7 @@ struct CShuffleEpilogue
     *
     * @return The vector store size for C tensor.
     */
+    template <typename ODataType>
    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
    {
        constexpr index_t MaxVectorStoreSize = 16;
...
@@ -142,7 +143,7 @@ struct CShuffleEpilogue
            TileDistributionEncodingPattern2D<kBlockSize,
                                              kMPerIteration,
                                              kNPerIteration,
-                                             GetVectorSizeC(),
+                                             GetVectorSizeC<ODataType>(),
                                              tile_distribution_pattern::thread_raked>;
        constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
...
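The template parameter added above lets the epilogue derive the store vector width from the output element type instead of one fixed type. A minimal sketch of the idea follows; the body of GetVectorSizeC is not shown in the hunks above, so the formula used here (MaxVectorStoreSize bytes divided by the element size) is an assumption for illustration, and uint16_t merely stands in for a 2-byte element type such as half.

#include <cstdint>
#include <cstdio>

template <typename ODataType>
constexpr int get_vector_size_c_sketch()
{
    constexpr int MaxVectorStoreSize = 16; // bytes, as in CShuffleEpilogue
    return MaxVectorStoreSize / static_cast<int>(sizeof(ODataType));
}

int main()
{
    std::printf("2-byte output: %d elements per store\n", get_vector_size_c_sketch<std::uint16_t>()); // 8
    std::printf("4-byte output: %d elements per store\n", get_vector_size_c_sketch<float>());         // 4
    return 0;
}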
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc

...
@@ -824,4 +824,4 @@
 #undef _UK_PK_CVT_
 #undef _UK_ATOMIC_ADD_
 #undef CK_TILE_FLATMM_UK_MFMA
-// clang-format on
+// clang-format on
(the removed and added lines are identical; the change is whitespace-only, e.g. a trailing newline)

include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc

...
@@ -722,4 +722,4 @@
 #undef _UK_PK_CVT_
 #undef _UK_ATOMIC_ADD_
 #undef CK_TILE_FLATMM_UK_MFMA
-// clang-format on
+// clang-format on
(whitespace-only change, as above)

include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc

...
@@ -771,4 +771,4 @@
 #undef _UK_MFMA_
 #undef CK_TILE_FLATMM_UK_2B
 #undef CK_TILE_FLATMM_UK_MFMA
-// clang-format on
+// clang-format on
(whitespace-only change, as above)
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp

...
@@ -79,7 +79,10 @@ struct BlockUniversalGemmAsBsCr
         // TODO: Should we have two policies? Interwave & Intrawave ??
         static constexpr index_t InterWaveSchedulingMacClusters = 1;

-        static constexpr index_t KPack      = WarpGemm::kKPerThread;
+        // should be at least equal to: WarpGemm::Impl::kABKPerLane
+        // and the question is how to assess upper limit or exact value?
+        // TODO: Should we introduce AK1/BK1 parameters ?
+        static constexpr index_t KPack      = 8;
         static constexpr index_t KPerThread = KIterPerWarp * KPack;
         static constexpr index_t KRepeat    = KPerThread / KPack;
     };
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp

...
@@ -159,7 +159,7 @@ struct GemmKernel
     CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
     {
-        if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
+        if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
                      is_any_of<CDataType, fp16_t, bf16_t>::value)
         {
             if(kargs.k_batch != 1)
...
@@ -240,7 +240,7 @@ struct GemmKernel
                       << std::endl;
             return false;
         }
-        if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
+        if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
         {
             std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
             return false;
...
@@ -255,7 +255,7 @@ struct GemmKernel
                       << std::endl;
             return false;
         }
-        if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
+        if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
         {
             std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
             return false;
...
@@ -321,7 +321,7 @@ struct GemmKernel
                 c_ptr,
                 make_tuple(kargs.M, kargs.N),
                 make_tuple(kargs.stride_C, 1),
-                number<EpiloguePipeline::GetVectorSizeC()>{},
+                number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
                 number<1>{});
         }
         else
...
@@ -589,7 +589,9 @@ struct GemmKernel
         }
         else
         {
-            if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
+            // Do not compile in case where we have unsupported
+            // VectorSizeC & data type configuration.
+            if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
                            is_any_of<CDataType, fp16_t, bf16_t>::value))
             {
                 if constexpr(GemmPipeline::DoubleSmemBuffer == true)
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp

...
@@ -3,6 +3,9 @@
 #pragma once

+#include <string>
+#include <sstream>
+
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
...
@@ -78,11 +81,63 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
     static constexpr auto TailNum   = Problem::TailNum;
     static constexpr auto Scheduler = Problem::Scheduler;

+    using Base::PrefetchStages;
+
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
         return Policy::template GetSmemSize<Problem>();
     }

+    CK_TILE_HOST static std::string Print()
+    {
+        constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
+        constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
+        constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
+
+        constexpr index_t WaveSize = 64;
+        constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
+        constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
+
+        // Below should be equal to AK1|BK1
+        constexpr index_t A_LDS_Read_Width  = Policy::template GetSmemPackA<Problem>();
+        constexpr index_t B_LDS_Read_Width  = Policy::template GetSmemPackB<Problem>();
+        constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
+        constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
+
+        constexpr index_t A_Buffer_Load_Inst_Num = MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
+        constexpr index_t B_Buffer_Load_Inst_Num = NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
+
+        constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
+        constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
+
+        constexpr index_t A_LDS_Read_Inst_Num = WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
+        constexpr index_t B_LDS_Read_Inst_Num = WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
+
+        constexpr index_t C_MFMA_Inst_Num =
+            MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
+
+        auto str = std::stringstream{};
+
+        str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n"
+            << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
+            << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num << "\n"
+            << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num << "\n"
+            << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
+            << "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
+            << "KPack: " << BlockGemm::Traits::KPack << "\n"
+            << "PrefetchStages: " << PrefetchStages << "\n";
+        return str.str();
+    }
+
     template <GemmPipelineScheduler Scheduler>
     struct PipelineImpl : public PipelineImplBase
     {
...
@@ -95,29 +150,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
         CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
         {
-            constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
-            constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
-            constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
+            constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
+            constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
+            constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;

             constexpr index_t WaveSize = 64;
             constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
             constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

-            constexpr index_t A_LDS_Read_Width = KPerXDL;
-            constexpr index_t B_LDS_Read_Width = KPerXDL;
+            // Below should be equal to AK1|BK1
+            constexpr index_t A_LDS_Read_Width  = Policy::template GetSmemPackA<Problem>();
+            constexpr index_t B_LDS_Read_Width  = Policy::template GetSmemPackB<Problem>();
+
+            constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
+            constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();

             constexpr index_t A_Buffer_Load_Inst_Num = MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
             constexpr index_t B_Buffer_Load_Inst_Num = NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());

-            constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
-            constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
+            constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
+            constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);

             constexpr index_t A_LDS_Read_Inst_Num =
-                WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
+                WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
             constexpr index_t B_LDS_Read_Inst_Num =
-                WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
+                WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);

             constexpr index_t C_MFMA_Inst_Num =
                 MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) /
...
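As a sanity check on the counters introduced in Print() and reused in HotLoopScheduler(), here is the arithmetic for one assumed tile configuration; every constant below is an illustrative assumption, not a value taken from this commit.

#include <cstdio>

int main()
{
    // Illustrative, assumed tile parameters (not from the commit).
    constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 64;
    constexpr int BlockSize = 256, WaveSize = 64;
    constexpr int VectorSizeA = 8, A_LDS_Write_Width = 8;
    constexpr int MPerXDL = 16, NPerXDL = 16, KPerXDL = 16;

    // Same formulas as Print()/HotLoopScheduler() above.
    constexpr int a_buffer_loads = MPerBlock * KPerBlock / (BlockSize * VectorSizeA);       // 4
    constexpr int a_lds_writes   = MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); // 4
    constexpr int c_mfma =
        MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); // 64

    std::printf("A buffer loads: %d, A LDS writes: %d, MFMA per wave: %d\n",
                a_buffer_loads, a_lds_writes, c_mfma);
    return 0;
}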
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp

...
@@ -90,7 +90,7 @@ struct BaseGemmPipelineAgBgCrMem
 // LocalPreFillStages: 1
 // LocalPreFetchStages: 0
 // LocalSharedMemoryBuffer: 1
-template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
+template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
 struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
 {
     using Base = BaseGemmPipelineAgBgCrMem<Problem>;
...
@@ -167,11 +167,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
                       "A/B Dram block window should have the same data type as appropriate "
                       "([A|B]DataType) defined in Problem definition!");

-        static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
-                          NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
-                          KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
-                      "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
-                      " or KPerBlock!");
+        constexpr bool is_a_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
+        constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
+
+        static_assert(is_a_col_major
+                          ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                             MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
+                          : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                             KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
+                      "A block window has incorrect lengths for defined ALayout!");
+        static_assert(is_b_row_major
+                          ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                             NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
+                          : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                             KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
+                      "B block window has incorrect lengths for defined BLayout!");

         // ------------------------------------------------------------------------------------
         // Definitions of all needed tiles
...
@@ -215,25 +226,59 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
         tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
         tuple_array<BBlockTile, PrefetchStages> b_block_tiles;

+        using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
+        using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
+
+        constexpr ADramTileWindowStep a_dram_tile_window_step =
+            is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
+        constexpr BDramTileWindowStep b_dram_tile_window_step =
+            is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
+
         // -----------------------------------------------------------------------------------------
         // Gemm pipeline start

         // prefetch
         // global read 0
-        Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
-        Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
+        Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
+        Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);

         // initialize C
         tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

         // LDS write 0
-        Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
-        Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+        if constexpr(is_a_col_major)
+        {
+            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                Policy::template MakeShuffledARegTileDistribution<Problem>());
+            transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
+            Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+        }
+        else
+        {
+            Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
+        }
+        if constexpr(is_b_row_major)
+        {
+            auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                Policy::template MakeShuffledBRegTileDistribution<Problem>());
+            transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
+            Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+        }
+        else
+        {
+            Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+        }

         // Global prefetch [1, PrefetchStages]
         static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
-            Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
-            Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
+            Base::GlobalPrefetch(
+                a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window, a_dram_tile_window_step);
+            Base::GlobalPrefetch(
+                b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window, b_dram_tile_window_step);
         });

         // main body
...
@@ -249,19 +294,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
                 block_sync_lds();

-                Base::LocalPrefill(a_copy_lds_window,
-                                   a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
-                                   a_element_func);
-                Base::LocalPrefill(b_copy_lds_window,
-                                   b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
-                                   b_element_func);
+                if constexpr(is_a_col_major)
+                {
+                    auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                        Policy::template MakeShuffledARegTileDistribution<Problem>());
+                    transpose_tile2d(
+                        a_shuffle_tmp, a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+                    Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+                }
+                else
+                {
+                    Base::LocalPrefill(a_copy_lds_window,
+                                       a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+                                       a_element_func);
+                }
+                if constexpr(is_b_row_major)
+                {
+                    auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                        Policy::template MakeShuffledBRegTileDistribution<Problem>());
+                    transpose_tile2d(
+                        b_shuffle_tmp, b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+                    Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+                }
+                else
+                {
+                    Base::LocalPrefill(b_copy_lds_window,
+                                       b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+                                       b_element_func);
+                }

                 Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
-                                     a_copy_dram_window);
+                                     a_copy_dram_window,
+                                     a_dram_tile_window_step);
                 Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
-                                     b_copy_dram_window);
+                                     b_copy_dram_window,
+                                     b_dram_tile_window_step);
             });

             i += PrefetchStages;
...
@@ -277,12 +348,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
                 block_sync_lds();

-                Base::LocalPrefill(a_copy_lds_window,
-                                   a_block_tiles.get(number<prefetch_idx>{}),
-                                   a_element_func);
-                Base::LocalPrefill(b_copy_lds_window,
-                                   b_block_tiles.get(number<prefetch_idx>{}),
-                                   b_element_func);
+                if constexpr(is_a_col_major)
+                {
+                    auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                        Policy::template MakeShuffledARegTileDistribution<Problem>());
+                    transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
+                    Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+                }
+                else
+                {
+                    Base::LocalPrefill(a_copy_lds_window,
+                                       a_block_tiles.get(number<prefetch_idx>{}),
+                                       a_element_func);
+                }
+                if constexpr(is_b_row_major)
+                {
+                    auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                        Policy::template MakeShuffledBRegTileDistribution<Problem>());
+                    transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
+                    Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+                }
+                else
+                {
+                    Base::LocalPrefill(b_copy_lds_window,
+                                       b_block_tiles.get(number<prefetch_idx>{}),
+                                       b_element_func);
+                }
             });

             block_sync_lds();
...
@@ -354,11 +445,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
(same change as the hunk at -167: the single M/N/K window-length static_assert is replaced by
the is_a_col_major / is_b_row_major constants and the two layout-aware static_asserts, applied
to the second Run() overload)
...
@@ -402,25 +504,58 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
(same change as the hunk at -215: a_dram_tile_window_step / b_dram_tile_window_step are
introduced, GlobalPrefetch gains the step argument, and the initial LocalPrefill of A and B is
wrapped in the layout-dependent transpose_tile2d shuffle, applied to the second Run() overload)
...
@@ -434,19 +569,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
                 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                 // no second block_sync_lds because it's interwave
(same change as the hunk at -249: the hot-loop LocalPrefill of A and B becomes the
layout-dependent shuffle + prefill, and GlobalPrefetch gains the dram tile window step argument)
             });

             i += PrefetchStages;
...
@@ -459,12 +620,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
                 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                 // no second block_sync_lds because it's interwave
(same change as the hunk at -277: the tail-loop LocalPrefill of A and B becomes the
layout-dependent shuffle + prefill)
             });

             block_sync_lds();
...
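The recurring pattern in the hunks above: when A is column-major (or B is row-major), K is the faster-varying axis of the DRAM window, so the window is stepped by (KPerBlock, 0) instead of (0, KPerBlock), and the freshly loaded register tile is transposed with transpose_tile2d before it is written to LDS. A minimal sketch of the step selection only; std::array stands in for the ck_tile make_array helper, and the layout flag is an assumed value.

#include <array>
#include <cstdio>

int main()
{
    constexpr int KPerBlock       = 64;
    constexpr bool is_a_col_major = true; // assumed layout for illustration

    // Mirrors: is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock)
    const std::array<int, 2> a_step =
        is_a_col_major ? std::array<int, 2>{KPerBlock, 0} : std::array<int, 2>{0, KPerBlock};

    std::printf("A window step per K iteration: {%d, %d}\n", a_step[0], a_step[1]);
    return 0;
}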
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp

...
@@ -323,7 +323,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
     {
         using ADataType = remove_cvref_t<typename Problem::ADataType>;
-
         constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
...
library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename ComputeTypeA = CDataType,
          typename ComputeTypeB = ComputeTypeA>
struct ReferenceMXGemm : public device::BaseOperator
{
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_m_k,
                 const Tensor<ScaleDataType>& a_m_kblock_scales,
                 const Tensor<BDataType>& b_k_n,
                 const Tensor<ScaleDataType>& b_kblock_n_scales,
                 Tensor<CDataType>& c_m_n,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : a_m_k_{a_m_k},
              a_m_kblock_scales_{a_m_kblock_scales},
              b_k_n_{b_k_n},
              b_kblock_n_scales_{b_kblock_n_scales},
              c_m_n_{c_m_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
        }

        const Tensor<ADataType>& a_m_k_;
        const Tensor<ScaleDataType>& a_m_kblock_scales_;
        const Tensor<BDataType>& b_k_n_;
        const Tensor<ScaleDataType>& b_kblock_n_scales_;
        Tensor<CDataType>& c_m_n_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceMXGemm::Argument;

        float Run(const Argument& arg)
        {
            using GemmInstance = ck::tensor_operation::host::ReferenceGemm<ComputeTypeA,
                                                                           ComputeTypeB,
                                                                           CDataType,
                                                                           AccDataType,
                                                                           AElementwiseOperation,
                                                                           BElementwiseOperation,
                                                                           CElementwiseOperation,
                                                                           ComputeTypeA,
                                                                           ComputeTypeB>;

            Tensor<ComputeTypeA> a_m_k_scaled(arg.a_m_k_.mDesc);
            Tensor<ComputeTypeB> b_k_n_scaled(arg.b_k_n_.mDesc);

            const auto M = arg.a_m_k_.mDesc.GetLengths()[0];
            const auto N = arg.b_k_n_.mDesc.GetLengths()[1];
            const auto K = arg.a_m_k_.mDesc.GetLengths()[1];

            const auto SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1];

            for(size_t m = 0; m < M; m++)
            {
                for(size_t k = 0; k < K; k++)
                {
                    a_m_k_scaled(m, k) =
                        type_convert<ComputeTypeA>(arg.a_m_k_(m, k)) *
                        type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
                }
            }
            for(size_t n = 0; n < N; n++)
            {
                for(size_t k = 0; k < K; k++)
                {
                    b_k_n_scaled(k, n) =
                        type_convert<ComputeTypeB>(arg.b_k_n_(k, n)) *
                        type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
                }
            }

            auto ref_gemm     = GemmInstance{};
            auto ref_invoker  = ref_gemm.MakeInvoker();
            auto ref_argument = ref_gemm.MakeArgument(a_m_k_scaled,
                                                      b_k_n_scaled,
                                                      arg.c_m_n_,
                                                      arg.a_element_op_,
                                                      arg.b_element_op_,
                                                      arg.c_element_op_);
            ref_invoker.Run(ref_argument);

            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
                             const Tensor<ScaleDataType>& a_m_kblock_scales,
                             const Tensor<BDataType>& b_k_n,
                             const Tensor<ScaleDataType>& b_kblock_n_scales,
                             Tensor<CDataType>& c_m_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{a_m_k,
                        a_m_kblock_scales,
                        b_k_n,
                        b_kblock_n_scales,
                        c_m_n,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
        // clang-format off
        str << "ReferenceMXGemm" << std::endl;
        // clang-format on
        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
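A hedged usage sketch of the new reference operator, following the MakeArgument/MakeInvoker pattern visible in the file itself. This is a fragment, not a complete program: the concrete data types, the PassThrough element ops, and the pre-filled host tensors (a_m_k, a_scales, b_k_n, b_scales, c_m_n) are assumptions made for illustration.

// Assumes ADataType/BDataType/CDataType/AccDataType/ScaleDataType are chosen
// elsewhere and the Tensor<> objects below are already allocated and filled.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using RefMXGemm = ck::tensor_operation::host::ReferenceMXGemm<ADataType,
                                                              BDataType,
                                                              CDataType,
                                                              AccDataType,
                                                              ScaleDataType,
                                                              PassThrough,
                                                              PassThrough,
                                                              PassThrough>;

auto ref_op       = RefMXGemm{};
auto ref_invoker  = ref_op.MakeInvoker();
auto ref_argument = ref_op.MakeArgument(
    a_m_k, a_scales, b_k_n, b_scales, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);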
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
View file @
9ba504b6
...
...
@@ -41,13 +41,16 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
// clang-format on
#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>
;
template
<
ck
::
index_t
NDimSpatial
,
...
...
@@ -58,11 +61,13 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
>
,
...
...
@@ -72,6 +77,7 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
32
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
128
,
32
,
32
,
8
,
32
,
32
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
>
#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>
;
...
...
@@ -106,13 +112,16 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
// clang-format on
#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>
;
template
<
ck
::
index_t
NDimSpatial
,
...
...
@@ -123,11 +132,13 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>,
...
...
@@ -137,6 +148,7 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instance
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>
#endif // defined(CK_USE_AMD_MFMA_GFX950)
    // clang-format on
    >;
...
...
@@ -171,13 +183,16 @@ template <ck::index_t NDimSpatial,
          BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances = std::tuple<
    // clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1>
#endif // defined(CK_USE_AMD_MFMA_GFX950)
    // clang-format on
    >;
// NGCHW requires transpose, we use vector loads and stores params for them
...
...
@@ -189,11 +204,13 @@ template <ck::index_t NDimSpatial,
          BlockGemmPipelineScheduler Scheduler,
          BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances = std::tuple<
    // clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 2>,
...
...
@@ -217,6 +234,7 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1>
#endif // defined(CK_USE_AMD_MFMA_GFX950)
    // clang-format on
    >;
...
...
@@ -229,13 +247,16 @@ template <ck::index_t NDimSpatial,
          BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances = std::tuple<
    // clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1>
#endif // defined(CK_USE_AMD_MFMA_GFX950)
    // clang-format on
    >;
template <ck::index_t NDimSpatial,
...
...
@@ -246,11 +267,13 @@ template <ck::index_t NDimSpatial,
          BlockGemmPipelineScheduler Scheduler,
          BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances = std::tuple<
    // clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#else
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, BF16, BF16, 2, 2>,
...
...
@@ -274,6 +297,7 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instance
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 4, 1>,
        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8, 1>
#endif // defined(CK_USE_AMD_MFMA_GFX950)
    // clang-format on
    >;
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp
...
...
@@ -56,11 +56,13 @@ template <index_t NDimSpatial,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
// Compute friendly
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
...
...
@@ -79,7 +81,7 @@ using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple<
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>
#endif // defined(__gfx950__)
    // clang-format on
    >;
...
...
@@ -90,11 +92,13 @@ template <index_t NDimSpatial,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
...
...
@@ -109,6 +113,7 @@ using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple<
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
#endif // defined(__gfx950__)
    // clang-format on
    >;
...
...
@@ -138,11 +143,13 @@ template <index_t NDimSpatial,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_int8_comp_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
...
...
@@ -153,6 +160,7 @@ using device_grouped_conv_fwd_xdl_int8_comp_instances = std::tuple<
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
#endif // defined(__gfx950__)
    // clang-format on
    >;
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
...
...
@@ -40,15 +40,18 @@ template <index_t NDimSpatial,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
// Instances with NumGroupsPerBatch > 1
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32>
#endif // defined(__gfx950__)
    // clang-format on
    >;
...
...
@@ -59,15 +62,18 @@ template <index_t NDimSpatial,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple<
    // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
// Instances with NumGroupsPerBatch > 1
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>,
        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32>
#endif // defined(__gfx950__)
    // clang-format on
    >;
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
...
...
@@ -69,7 +69,7 @@ function(add_instance_library INSTANCE_NAME)
    endforeach()
    # Do not build mha instances if gfx94 or gfx90a targets are not on the target list
    foreach(source IN LISTS ARGN)
-       if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha")
+       if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "mha")
            message("removing mha instance ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
...
...
@@ -77,25 +77,25 @@ function(add_instance_library INSTANCE_NAME)
    # Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
    if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
        foreach(source IN LISTS ARGN)
-           if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "gemm_multiply_multiply_xdl_f8")
+           if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply_xdl_f8")
                message("removing gemm_multiply_multiply_f8 instance ${source}")
                list(REMOVE_ITEM ARGN "${source}")
            endif()
        endforeach()
        foreach(source IN LISTS ARGN)
-           if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "gemm_xdl_universal" AND source MATCHES "_f8_")
+           if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_xdl_universal" AND source MATCHES "_f8_")
                message("removing gemm_universal_f8 instance ${source}")
                list(REMOVE_ITEM ARGN "${source}")
            endif()
        endforeach()
        foreach(source IN LISTS ARGN)
-           if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "batched_gemm_xdl_universal" AND source MATCHES "_f8_")
+           if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "batched_gemm_xdl_universal" AND source MATCHES "_f8_")
                message("removing batched_gemm_universal_f8 instance ${source}")
                list(REMOVE_ITEM ARGN "${source}")
            endif()
        endforeach()
        foreach(source IN LISTS ARGN)
-           if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "gemm_xdl_universal_streamk" AND source MATCHES "_f8_")
+           if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_xdl_universal_streamk" AND source MATCHES "_f8_")
                message("removing gemm_universal_streamk_f8 instance ${source}")
                list(REMOVE_ITEM ARGN "${source}")
            endif()
...
...
@@ -109,7 +109,7 @@ function(add_instance_library INSTANCE_NAME)
        if(source MATCHES "_xdl")
            list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
        elseif(source MATCHES "_wmma")
-           list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+           list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
        elseif(source MATCHES "mha")
            list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
        endif()
...
...
@@ -368,7 +368,7 @@ if(CK_DEVICE_CONV_INSTANCES)
endif()
if(CK_DEVICE_MHA_INSTANCES)
    set(gpu_list ${INST_TARGETS})
-   if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a")
+   if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a" OR gpu_list MATCHES "gfx95")
        add_library(device_mha_operations ${CK_DEVICE_MHA_INSTANCES})
        set_target_properties(device_mha_operations
            PROPERTIES
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
...
...
@@ -27,12 +27,15 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple<
    // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
#if defined(CK_USE_AMD_MFMA_GFX950)
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#else
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
...
...
@@ -65,6 +68,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple<
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>,
        DeviceBatchedGemmXdl<F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif // defined(CK_USE_AMD_MFMA_GFX950)
    // clang-format on
    >;
...
...