gaoqiong / composable_kernel_ROCM / Commits / 175a17f8

Commit 175a17f8, authored Nov 23, 2024 by Rostyslav Geyyer

Merge branch 'gfx950' of https://github.com/ROCm/composable_kernel-internal into lwpck-2390

Parents: 3e520bbd, 1504c3e8
Changes: 138
Showing 20 changed files with 1169 additions and 259 deletions (+1169 -259)
Changed files shown:

  include/ck/utility/amd_wmma.hpp                                                        +3    -2
  include/ck/utility/data_type.hpp                                                       +253  -74
  include/ck/utility/math_v2.hpp                                                         +2    -2
  include/ck/utility/random_gen.hpp                                                      +8    -5
  include/ck/utility/type_convert.hpp                                                    +128  -60
  include/ck_tile/core/config.hpp                                                        +5    -3
  include/ck_tile/core/tensor/shuffle_tile.hpp                                           +1    -1
  include/ck_tile/host.hpp                                                               +1    -0
  include/ck_tile/host/reference/reference_moe_sorting.hpp                               +78   -0
  include/ck_tile/ops/common/generic_2d_block_shape.hpp                                  +6    -6
  include/ck_tile/ops/fmha/block/page_block_navigator.hpp                                +9    -1
  include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp       +2    -0
  include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp                            +232  -0
  include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp                        +39   -0
  include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp                          +15   -0
  include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp                         +23   -0
  include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp                                        +50   -20
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp                       +3    -3
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp                +44   -19
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +267  -63
include/ck/utility/amd_wmma.hpp (View file @ 175a17f8)

@@ -9,7 +9,8 @@
 // TODO: Add arch limitation
 namespace ck {
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
+    defined(__gfx1103__) || defined(__gfx11_generic__)
 #define __gfx11__
 #endif
 /********************************WAVE32 MODE***********************************************/

@@ -260,7 +261,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
 // gfx12
 /********************************WAVE32 MODE***********************************************/
-#if defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
 #define __gfx12__
 #endif
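The functional change in this file is only that the generic gfx11/gfx12 compiler targets now also define the umbrella macros. A minimal sketch of how downstream code is expected to consume them (the guarded body is illustrative, not taken from this file):

    #if defined(__gfx11__) || defined(__gfx12__)
        // wave32 WMMA intrinsic paths compile on these targets,
        // including the new __gfx11_generic__ / __gfx12_generic__ builds
    #else
        // fall back to a non-WMMA path
    #endif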
include/ck/utility/data_type.hpp (View file @ 175a17f8)

@@ -3,6 +3,7 @@
 #pragma once
 #include "ck/utility/amd_ck_fp8.hpp"
+#include "ck/utility/statically_indexed_array.hpp"

 namespace ck {

@@ -11,8 +12,6 @@ using bhalf_t = ushort;
 using half_t = _Float16;
 using int4_t = _BitInt(4);
 using f4_t   = unsigned _BitInt(4);
 using f8_t   = _BitInt(8);
 using bf8_t  = unsigned _BitInt(8);

 struct e8m0_bexp_t
 {

@@ -53,14 +52,20 @@ inline constexpr auto next_pow2(uint32_t x)
     return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
 }

-// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool
+// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
+// native types: bool
 template <typename T>
 inline constexpr bool is_native_type()
 {
     return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
            is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
<<<<<<< HEAD
            is_same<T, uint8_t>::value || is_same<T, f8_t>::value || is_same<T, bf8_t>::value ||
            is_same<T, bool>::value || is_same<T, f4_t>::value;
=======
            is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value ||
            is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value;
>>>>>>> 1504c3e8f552f519e963f3bc1880946c00a7fbc5
 }

 // vector_type

@@ -200,16 +205,30 @@ struct scalar_type<int4_t>
 #endif

 template <>
-struct scalar_type<f8_t>
+struct scalar_type<f8_fnuz_t>
 {
-    using type = f8_t;
+    using type = f8_fnuz_t;
     static constexpr index_t vector_size = 1;
 };

 template <>
-struct scalar_type<bf8_t>
+struct scalar_type<bf8_fnuz_t>
 {
-    using type = bf8_t;
+    using type = bf8_fnuz_t;
     static constexpr index_t vector_size = 1;
 };

+template <>
+struct scalar_type<f8_ocp_t>
+{
+    using type = f8_ocp_t::data_type;
+    static constexpr index_t vector_size = 1;
+};
+
+template <>
+struct scalar_type<bf8_ocp_t>
+{
+    using type = bf8_ocp_t::data_type;
+    static constexpr index_t vector_size = 1;
+};

@@ -1057,47 +1076,83 @@ struct non_native_vector_base
     T d[N];
 };

+template <typename T, index_t N>
+struct scalar_type<non_native_vector_base<T, N>>;
+
+template <index_t N>
+struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
+{
+    using type = typename non_native_vector_base<f8_ocp_t, N>::data_t;
+    static constexpr index_t vector_size = N;
+};
+
+template <index_t N>
+struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
+{
+    using type = typename non_native_vector_base<bf8_ocp_t, N>::data_t;
+    static constexpr index_t vector_size = N;
+};
+
 // non-native vector_type implementation
 template <typename T>
 struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
 {
-    using d1_t = T;
-    using type = d1_t;
+    using d1_t     = T;
+    using d1_nnv_t = non_native_vector_base<T, 1>;
+    using type     = d1_nnv_t;

     union alignas(next_pow2(1 * sizeof(T)))
     {
         d1_t d1_;
         StaticallyIndexedArray<d1_t, 1> d1x1_;
+        d1_nnv_t d1_nnv_;
     } data_;

-    __host__ __device__ constexpr vector_type() : data_{type{}} {}
+    __host__ __device__ constexpr vector_type() : data_{d1_t{}} {}

-    __host__ __device__ constexpr vector_type(type v) : data_{v} {}
+    __host__ __device__ constexpr vector_type(type v) : data_{{v}} {}

     template <typename X>
     __host__ __device__ constexpr const auto& AsType() const
     {
-        static_assert(is_same<X, d1_t>::value,
+        static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value,
                       "Something went wrong, please check src and dst types.");
-        return data_.d1x1_;
+        if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
+        {
+            return data_.d1x1_;
+        }
+        else
+        {
+            return err;
+        }
     }

     template <typename X>
     __host__ __device__ constexpr auto& AsType()
     {
-        static_assert(is_same<X, d1_t>::value,
+        static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value,
                       "Something went wrong, please check src and dst types.");
-        return data_.d1x1_;
+        if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
+        {
+            return data_.d1x1_;
+        }
+        else
+        {
+            return err;
+        }
     }
 };

 template <typename T>
 struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
 {
-    using d1_t = T;
-    using d2_t = non_native_vector_base<T, 2>;
+    using d1_t     = T;
+    using d1_nnv_t = non_native_vector_base<T, 1>;
+    using d2_t     = non_native_vector_base<T, 2>;

     using type = d2_t;

@@ -1115,10 +1170,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
     template <typename X>
     __host__ __device__ constexpr const auto& AsType() const
     {
-        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
+        static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
+                          is_same<X, d2_t>::value,
                       "Something went wrong, please check src and dst types.");

-        if constexpr(is_same<X, d1_t>::value)
+        if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
         {
             return data_.d1x2_;
         }

@@ -1135,10 +1191,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
     template <typename X>
     __host__ __device__ constexpr auto& AsType()
     {
-        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
+        static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
+                          is_same<X, d2_t>::value,
                       "Something went wrong, please check src and dst types.");

-        if constexpr(is_same<X, d1_t>::value)
+        if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
         {
             return data_.d1x2_;
         }

@@ -1156,9 +1213,10 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
 template <typename T>
 struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
 {
-    using d1_t = T;
-    using d2_t = non_native_vector_base<T, 2>;
-    using d4_t = non_native_vector_base<T, 4>;
+    using d1_t     = T;
+    using d1_nnv_t = non_native_vector_base<T, 1>;
+    using d2_t     = non_native_vector_base<T, 2>;
+    using d4_t     = non_native_vector_base<T, 4>;

     using type = d4_t;

@@ -1177,10 +1235,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
     template <typename X>
     __host__ __device__ constexpr const auto& AsType() const
     {
-        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
-                          is_same<X, d4_t>::value,
+        static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
+                          is_same<X, d2_t>::value || is_same<X, d4_t>::value,
                       "Something went wrong, please check src and dst types.");

-        if constexpr(is_same<X, d1_t>::value)
+        if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
         {
             return data_.d1x4_;
         }

@@ -1201,10 +1260,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
     template <typename X>
     __host__ __device__ constexpr auto& AsType()
     {
-        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
-                          is_same<X, d4_t>::value,
+        static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
+                          is_same<X, d2_t>::value || is_same<X, d4_t>::value,
                       "Something went wrong, please check src and dst types.");

-        if constexpr(is_same<X, d1_t>::value)
+        if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
         {
             return data_.d1x4_;
         }

@@ -1226,10 +1286,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
 template <typename T>
 struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
 {
-    using d1_t = T;
-    using d2_t = non_native_vector_base<T, 2>;
-    using d4_t = non_native_vector_base<T, 4>;
-    using d8_t = non_native_vector_base<T, 8>;
+    using d1_t     = T;
+    using d1_nnv_t = non_native_vector_base<T, 1>;
+    using d2_t     = non_native_vector_base<T, 2>;
+    using d4_t     = non_native_vector_base<T, 4>;
+    using d8_t     = non_native_vector_base<T, 8>;

     using type = d8_t;

@@ -1249,11 +1310,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
     template <typename X>
     __host__ __device__ constexpr const auto& AsType() const
     {
-        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
-                          is_same<X, d4_t>::value || is_same<X, d8_t>::value,
+        static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
+                          is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
+                          is_same<X, d8_t>::value,
                       "Something went wrong, please check src and dst types.");

-        if constexpr(is_same<X, d1_t>::value)
+        if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
         {
             return data_.d1x8_;
         }

@@ -1278,11 +1340,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
     template <typename X>
     __host__ __device__ constexpr auto& AsType()
     {
-        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
-                          is_same<X, d4_t>::value || is_same<X, d8_t>::value,
+        static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
+                          is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
+                          is_same<X, d8_t>::value,
                       "Something went wrong, please check src and dst types.");

-        if constexpr(is_same<X, d1_t>::value)
+        if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
         {
             return data_.d1x8_;
         }

@@ -1308,11 +1371,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
 template <typename T>
 struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
 {
-    using d1_t  = T;
-    using d2_t  = non_native_vector_base<T, 2>;
-    using d4_t  = non_native_vector_base<T, 4>;
-    using d8_t  = non_native_vector_base<T, 8>;
-    using d16_t = non_native_vector_base<T, 16>;
+    using d1_t     = T;
+    using d1_nnv_t = non_native_vector_base<T, 1>;
+    using d2_t     = non_native_vector_base<T, 2>;
+    using d4_t     = non_native_vector_base<T, 4>;
+    using d8_t     = non_native_vector_base<T, 8>;
+    using d16_t    = non_native_vector_base<T, 16>;

     using type = d16_t;

@@ -1333,12 +1397,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
     template <typename X>
     __host__ __device__ constexpr const auto& AsType() const
     {
-        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
-                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
-                          is_same<X, d16_t>::value,
+        static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
+                          is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
+                          is_same<X, d8_t>::value || is_same<X, d16_t>::value,
                       "Something went wrong, please check src and dst types.");

-        if constexpr(is_same<X, d1_t>::value)
+        if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
         {
             return data_.d1x16_;
         }

@@ -1367,12 +1431,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
     template <typename X>
     __host__ __device__ constexpr auto& AsType()
     {
-        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
-                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
-                          is_same<X, d16_t>::value,
+        static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
+                          is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
+                          is_same<X, d8_t>::value || is_same<X, d16_t>::value,
                       "Something went wrong, please check src and dst types.");

-        if constexpr(is_same<X, d1_t>::value)
+        if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
         {
             return data_.d1x16_;
         }

@@ -1666,20 +1730,70 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
 using int8x64_t = typename vector_type<int8_t, 64>::type;

 // f8
-using f8x2_t  = typename vector_type<f8_t, 2>::type;
-using f8x4_t  = typename vector_type<f8_t, 4>::type;
-using f8x8_t  = typename vector_type<f8_t, 8>::type;
-using f8x16_t = typename vector_type<f8_t, 16>::type;
-using f8x32_t = typename vector_type<f8_t, 32>::type;
-using f8x64_t = typename vector_type<f8_t, 64>::type;
+using f8x2_fnuz_t  = typename vector_type<f8_fnuz_t, 2>::type;
+using f8x4_fnuz_t  = typename vector_type<f8_fnuz_t, 4>::type;
+using f8x8_fnuz_t  = typename vector_type<f8_fnuz_t, 8>::type;
+using f8x16_fnuz_t = typename vector_type<f8_fnuz_t, 16>::type;
+using f8x32_fnuz_t = typename vector_type<f8_fnuz_t, 32>::type;
+using f8x64_fnuz_t = typename vector_type<f8_fnuz_t, 64>::type;

 // bf8
+using bf8x2_fnuz_t  = typename vector_type<bf8_fnuz_t, 2>::type;
+using bf8x4_fnuz_t  = typename vector_type<bf8_fnuz_t, 4>::type;
+using bf8x8_fnuz_t  = typename vector_type<bf8_fnuz_t, 8>::type;
+using bf8x16_fnuz_t = typename vector_type<bf8_fnuz_t, 16>::type;
+using bf8x32_fnuz_t = typename vector_type<bf8_fnuz_t, 32>::type;
+using bf8x64_fnuz_t = typename vector_type<bf8_fnuz_t, 64>::type;
+
+// f8
+using f8x2_ocp_t  = typename vector_type<f8_ocp_t, 2>::type;
+using f8x4_ocp_t  = typename vector_type<f8_ocp_t, 4>::type;
+using f8x8_ocp_t  = typename vector_type<f8_ocp_t, 8>::type;
+using f8x16_ocp_t = typename vector_type<f8_ocp_t, 16>::type;
+using f8x32_ocp_t = typename vector_type<f8_ocp_t, 32>::type;
+using f8x64_ocp_t = typename vector_type<f8_ocp_t, 64>::type;
+
+// bf8
+using bf8x2_ocp_t  = typename vector_type<bf8_ocp_t, 2>::type;
+using bf8x4_ocp_t  = typename vector_type<bf8_ocp_t, 4>::type;
+using bf8x8_ocp_t  = typename vector_type<bf8_ocp_t, 8>::type;
+using bf8x16_ocp_t = typename vector_type<bf8_ocp_t, 16>::type;
+using bf8x32_ocp_t = typename vector_type<bf8_ocp_t, 32>::type;
+using bf8x64_ocp_t = typename vector_type<bf8_ocp_t, 64>::type;
+
+#if CK_FP8_TYPE_OCP
+// f8
+using f8x2_t  = f8x2_ocp_t;
+using f8x4_t  = f8x4_ocp_t;
+using f8x8_t  = f8x8_ocp_t;
+using f8x16_t = f8x16_ocp_t;
+using f8x32_t = f8x32_ocp_t;
+using f8x64_t = f8x64_ocp_t;

 // bf8
-using bf8x2_t  = typename vector_type<bf8_t, 2>::type;
-using bf8x4_t  = typename vector_type<bf8_t, 4>::type;
-using bf8x8_t  = typename vector_type<bf8_t, 8>::type;
-using bf8x16_t = typename vector_type<bf8_t, 16>::type;
-using bf8x32_t = typename vector_type<bf8_t, 32>::type;
-using bf8x64_t = typename vector_type<bf8_t, 64>::type;
+using bf8x2_t  = bf8x2_ocp_t;
+using bf8x4_t  = bf8x4_ocp_t;
+using bf8x8_t  = bf8x8_ocp_t;
+using bf8x16_t = bf8x16_ocp_t;
+using bf8x32_t = bf8x32_ocp_t;
+using bf8x64_t = bf8x64_ocp_t;
+#elif CK_FP8_TYPE_FNUZ
+// f8
+using f8x2_t  = f8x2_fnuz_t;
+using f8x4_t  = f8x4_fnuz_t;
+using f8x8_t  = f8x8_fnuz_t;
+using f8x16_t = f8x16_fnuz_t;
+using f8x32_t = f8x32_fnuz_t;
+using f8x64_t = f8x64_fnuz_t;
+
+// bf8
+using bf8x2_t  = bf8x2_fnuz_t;
+using bf8x4_t  = bf8x4_fnuz_t;
+using bf8x8_t  = bf8x8_fnuz_t;
+using bf8x16_t = bf8x16_fnuz_t;
+using bf8x32_t = bf8x32_fnuz_t;
+using bf8x64_t = bf8x64_fnuz_t;
+#endif

 // u8
 using uint8x2_t = typename vector_type<uint8_t, 2>::type;

@@ -1744,7 +1858,7 @@ struct NumericLimits<int4_t>
 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4

 template <>
-struct NumericLimits<f8_t>
+struct NumericLimits<f8_fnuz_t>
 {
     // negative zero nan mode with exp bias = 8
     static constexpr uint8_t binary_min = 0x08; // 0b00001000

@@ -1757,17 +1871,17 @@ struct NumericLimits<f8_t>
     // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
     // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0

-    __host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
+    __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); }

-    __host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
+    __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); }

-    __host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
+    __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); }

-    __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
+    __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); }
 };

 template <>
-struct NumericLimits<bf8_t>
+struct NumericLimits<bf8_fnuz_t>
 {
     // negative zero nan mode with exp bias = 16
     static constexpr uint8_t binary_min = 0x04; // 0b00000100

@@ -1780,13 +1894,59 @@ struct NumericLimits<bf8_t>
     // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
     // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0

-    __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
+    __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); }

-    __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
+    __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); }

-    __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
+    __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); }

-    __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
+    __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); }
 };

+template <>
+struct NumericLimits<f8_ocp_t>
+{
+    static constexpr uint8_t binary_min    = 0x08; // 0b00001000 = 2^-6
+    static constexpr uint8_t binary_max    = 0x7E; // 0b01111110 = 448
+    static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448
+    static constexpr uint8_t binary_qnan   = 0x7F; // 0b01111111
+
+    __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast<f8_ocp_t>(binary_min); }
+
+    __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast<f8_ocp_t>(binary_max); }
+
+    __host__ __device__ static constexpr f8_ocp_t Lowest() { return bit_cast<f8_ocp_t>(binary_lowest); }
+
+    __host__ __device__ static constexpr f8_ocp_t QuietNaN() { return bit_cast<f8_ocp_t>(binary_qnan); }
+};
+
+template <>
+struct NumericLimits<bf8_ocp_t>
+{
+    static constexpr uint8_t binary_min    = 0x04; // 0b00000100 = 2^-14
+    static constexpr uint8_t binary_max    = 0x7B; // 0b01111011 = 57344
+    static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344
+    static constexpr uint8_t binary_qnan   = 0x7D; // 0b01111101
+
+    __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast<bf8_ocp_t>(binary_min); }
+
+    __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast<bf8_ocp_t>(binary_max); }
+
+    __host__ __device__ static constexpr bf8_ocp_t Lowest() { return bit_cast<bf8_ocp_t>(binary_lowest); }
+
+    __host__ __device__ static constexpr bf8_ocp_t QuietNaN() { return bit_cast<bf8_ocp_t>(binary_qnan); }
+};
+
 template <>

@@ -1884,6 +2044,7 @@ struct NumericUtils<half_t>
 };

 template <>
<<<<<<< HEAD
 struct NumericUtils<bhalf_t>
 {
     static constexpr int exp = 8;

@@ -1894,6 +2055,9 @@ struct NumericUtils<bhalf_t>
 template <>
 struct NumericUtils<f8_t>
=======
 struct NumericUtils<f8_fnuz_t>
>>>>>>> 1504c3e8f552f519e963f3bc1880946c00a7fbc5
 {
     static constexpr int exp  = 4;
     static constexpr int mant = 3;

@@ -1903,7 +2067,7 @@ struct NumericUtils<f8_t>
 };

 template <>
-struct NumericUtils<bf8_t>
+struct NumericUtils<bf8_fnuz_t>
 {
     static constexpr int exp  = 5;
     static constexpr int mant = 2;

@@ -1911,6 +2075,21 @@ struct NumericUtils<bf8_t>
     // static constexpr int bias = 15; // ieee mode
     static constexpr bool has_inf = false;
 };

+template <>
+struct NumericUtils<f8_ocp_t>
+{
+    static constexpr int exp  = 4;
+    static constexpr int mant = 3;
+    static constexpr int bias = 7;
+};
+
+template <>
+struct NumericUtils<bf8_ocp_t>
+{
+    static constexpr int exp  = 5;
+    static constexpr int mant = 2;
+    static constexpr int bias = 15;
+};
+
 template <>
 struct NumericUtils<f4_t>
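A small, hedged illustration of what the new alias layer buys: with the build-time switch CK_FP8_TYPE_OCP the unsuffixed f8/bf8 vector aliases resolve to the OCP types, and with CK_FP8_TYPE_FNUZ to the FNUZ types. The snippet below is a sketch; the asserts merely restate the #if/#elif block added above.

    #include <type_traits>
    #include "ck/utility/data_type.hpp"

    #if CK_FP8_TYPE_OCP
    static_assert(std::is_same_v<ck::f8x4_t, ck::f8x4_ocp_t>, "OCP aliases selected");
    #elif CK_FP8_TYPE_FNUZ
    static_assert(std::is_same_v<ck::f8x4_t, ck::f8x4_fnuz_t>, "FNUZ aliases selected");
    #endif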
include/ck/utility/math_v2.hpp (View file @ 175a17f8)

@@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x)
     return (xx & 0x7FFF) > 0x7C00;
 };

-static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };
+static inline __host__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };

 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 static inline __host__ bool isnan(int4_t x)

@@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x)
     return (xx & 0x7FFF) > 0x7C00;
 };

-static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };
+static inline __device__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };

 static inline __device__ half_t sqrt(half_t x)
 {
include/ck/utility/random_gen.hpp (View file @ 175a17f8)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck/ck.hpp"

 namespace ck {

 // Pseudo random number generator

@@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
 }

 // version for fp16
-template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
+template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
 __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
 {
     uint16_t x = *(reinterpret_cast<uint16_t*>(&val));

@@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
 }

 // return 0 if data is not fp16 or fp32
 template <typename T,
           uint32_t seed_t,
-          std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
+          std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
 __host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
 {
     std::ignore = id;
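The enable_if conditions now name _Float16 directly instead of half_t; assuming half_t remains an alias of _Float16 (as in data_type.hpp), existing call sites keep resolving to the same overload. A hedged usage sketch, with an arbitrary id and value:

    #include "ck/utility/random_gen.hpp"

    // picks the fp16 overload shown above
    ck::half_t v = ck::half_t{1.5f};
    uint32_t r   = ck::prand_generator<ck::half_t, 1254739>(/*id=*/0, v);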
include/ck/utility/type_convert.hpp (View file @ 175a17f8)

@@ -102,6 +102,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
     return type_convert<bhalf_t>(x_fp32);
 }

+template <>
+inline __host__ __device__ constexpr f8_ocp_t type_convert<f8_ocp_t, int>(int x)
+{
+    return f8_ocp_t{type_convert<f8_ocp_t::data_type>(x)};
+}
+
+template <>
+inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int x)
+{
+    return bf8_ocp_t{type_convert<bf8_ocp_t::data_type>(x)};
+}
+
 // Convert X to Y
 template <typename Y, typename X>
 __host__ __device__ constexpr Y type_convert_sp(X x)

@@ -165,7 +177,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);

 // convert fp32 to fp8 with stochastic rounding
 template <>
-inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
+inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
 {
     constexpr int seed = 1254739;
     uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);

@@ -191,33 +203,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
     constexpr bool clip           = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
-    return utils::cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
+    return utils::cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
 #endif
 }

 // convert fp16 to fp8 with stochastic rounding
 template <>
-inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
+inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
 {
 #if defined(__gfx94__)
     // convert to float and use native converion
-    return f8_convert_sr<f8_t>(type_convert<float>(x));
+    return f8_convert_sr<f8_fnuz_t>(type_convert<float>(x));
 #else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip              = true;
     constexpr f8_rounding_mode rm    = f8_rounding_mode::stochastic;
     constexpr int seed               = 1254739;
     uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
-    return utils::cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
+    return utils::cast_to_f8<half_t, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
 #endif
 }

 // convert fp32 to bf8 with stochastic rounding
 template <>
-inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
+inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
 {
     constexpr int seed = 1254739;
     uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);

@@ -242,28 +256,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
     constexpr bool negative_zero_nan = true;
     constexpr bool clip              = true;
     constexpr f8_rounding_mode rm    = f8_rounding_mode::stochastic;
-    return utils::cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
+    return utils::cast_to_f8<float, bf8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
 #endif
 }

 // convert fp16 to bf8 with stochastic rounding
 template <>
-inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
+inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x)
 {
 #if defined(__gfx94__)
     // convert to float and use native converion
-    return f8_convert_sr<bf8_t>(type_convert<float>(x));
+    return f8_convert_sr<bf8_fnuz_t>(type_convert<float>(x));
 #else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip              = true;
     constexpr f8_rounding_mode rm    = f8_rounding_mode::stochastic;
     constexpr int seed               = 1254739;
     uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
-    return utils::cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
+    return utils::cast_to_f8<half_t, bf8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
 #endif
 }

@@ -273,7 +291,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);

 // convert fp32 to fp8 with rounding to nearest even
 template <>
-inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
+inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, float>(float x)
 {
 #if defined(__gfx94__)
     union

@@ -298,32 +316,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
     constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
     constexpr uint32_t rng        = 0;
-    return utils::cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
+    return utils::cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
 #endif
 }

 // convert fp16 to fp8 with rounding to nearest even
 template <>
-inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
+inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, half_t>(half_t x)
 {
 #if defined(__gfx94__)
     // convert to float and use native converion
-    return f8_convert_rne<f8_t>(type_convert<float>(x));
+    return f8_convert_rne<f8_fnuz_t>(type_convert<float>(x));
 #else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip              = true;
     constexpr f8_rounding_mode rm    = f8_rounding_mode::standard;
     constexpr uint32_t rng           = 0;
-    return utils::cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
+    return utils::cast_to_f8<half_t, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
 #endif
 }

 // convert fp32 to bf8 with rounding to nearest even
 template <>
-inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
+inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, float>(float x)
 {
 #if defined(__gfx94__)
     union

@@ -347,44 +367,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
     constexpr bool clip           = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
     constexpr uint32_t rng        = 0;
-    return utils::cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
+    return utils::cast_to_f8<float, bf8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
 #endif
 }

 // convert fp16 to bf8 with rounding to nearest even
 template <>
-inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
+inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, half_t>(half_t x)
 {
 #if defined(__gfx94__)
     // convert to float and use native converion
-    return f8_convert_rne<bf8_t>(type_convert<float>(x));
+    return f8_convert_rne<bf8_fnuz_t>(type_convert<float>(x));
 #else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip              = true;
     constexpr f8_rounding_mode rm    = f8_rounding_mode::standard;
     constexpr uint32_t rng           = 0;
-    return utils::cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
+    return utils::cast_to_f8<half_t, bf8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, rng);
 #endif
 }

+// convert fp32 to fp8
+template <>
+inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<f8_fnuz_t>(x);
+#else
+    return f8_convert_rne<f8_fnuz_t>(x);
+#endif
+}
+
 // convert fp32 to fp8
 template <>
-inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
+inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
 {
 #if CK_USE_SR_F8_CONVERSION
-    return f8_convert_sr<f8_t>(x);
+    return f8_convert_sr<f8_ocp_t>(x);
 #else
-    return f8_convert_rne<f8_t>(x);
+    return f8_convert_rne<f8_ocp_t>(x);
 #endif
 }

 // convert fp8 to fp32
 template <>
-inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
+inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
 {
 #if defined(__gfx94__)
     float fval;

@@ -394,26 +429,26 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
     return fval;
 #else
     constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
+    return utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(x);
 #endif
 }

 template <>
-inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
+inline __host__ __device__ float2_t type_convert<float2_t, f8x2_fnuz_t>(f8x2_fnuz_t x)
 {
 #if defined(__gfx94__)
     const auto i16val = bit_cast<uint16_t>(x);
     return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
 #else
     constexpr bool negative_zero_nan = true;
-    const auto f8x2_v = vector_type<f8_t, 2>(x);
+    const auto f8x2_v = vector_type<f8_fnuz_t, 2>(x);
     vector_type<float, 2> f32x2_v;
     f32x2_v.template AsType<float>()(Number<0>{}) =
-        utils::cast_from_f8<f8_t, float, negative_zero_nan>(f8x2_v.template AsType<f8_t>()[Number<0>{}]);
+        utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(f8x2_v.template AsType<f8_fnuz_t>()[Number<0>{}]);
     f32x2_v.template AsType<float>()(Number<1>{}) =
-        utils::cast_from_f8<f8_t, float, negative_zero_nan>(f8x2_v.template AsType<f8_t>()[Number<1>{}]);
+        utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(f8x2_v.template AsType<f8_fnuz_t>()[Number<1>{}]);
     return f32x2_v.template AsType<float2_t>()[Number<0>{}];
 #endif
 }

@@ -430,42 +465,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)

 // convert fp16 to fp8
 template <>
-inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
+inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
 {
 #if CK_USE_SR_F8_CONVERSION
-    return f8_convert_sr<f8_t>(x);
+    return f8_convert_sr<f8_fnuz_t>(x);
 #else
-    return f8_convert_rne<f8_t>(x);
+    return f8_convert_rne<f8_fnuz_t>(x);
 #endif
 }

+// convert fp16 to fp8
+template <>
+inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<f8_ocp_t>(x);
+#else
+    return f8_convert_rne<f8_ocp_t>(x);
+#endif
+}
+
 // convert fp8 to fp16
 template <>
-inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
+inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
 {
 #if defined(__gfx94__)
     // use native conversion to float and convert to fp16
     return type_convert<half_t>(type_convert<float>(x));
 #else
     constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
+    return utils::cast_from_f8<f8_fnuz_t, half_t, negative_zero_nan>(x);
 #endif
 }

 // convert fp32 to bf8
 template <>
-inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
+inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
 {
 #if CK_USE_SR_F8_CONVERSION
-    return f8_convert_sr<bf8_t>(x);
+    return f8_convert_sr<bf8_fnuz_t>(x);
 #else
-    return f8_convert_rne<bf8_t>(x);
+    return f8_convert_rne<bf8_fnuz_t>(x);
 #endif
 }

+// convert fp32 to bf8
+template <>
+inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<bf8_ocp_t>(x);
+#else
+    return f8_convert_rne<bf8_ocp_t>(x);
+#endif
+}
+
 // convert bf8 to fp32
 template <>
-inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
+inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
 {
 #if defined(__gfx94__)
     float fval;

@@ -475,31 +532,42 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
     return fval;
 #else
     constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
+    return utils::cast_from_f8<bf8_fnuz_t, float, negative_zero_nan>(x);
 #endif
 }

+// convert fp16 to bf8
+template <>
+inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<bf8_fnuz_t>(x);
+#else
+    return f8_convert_rne<bf8_fnuz_t>(x);
+#endif
+}
+
 // convert fp16 to bf8
 template <>
-inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
+inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
 {
 #if CK_USE_SR_F8_CONVERSION
-    return f8_convert_sr<bf8_t>(x);
+    return f8_convert_sr<bf8_ocp_t>(x);
 #else
-    return f8_convert_rne<bf8_t>(x);
+    return f8_convert_rne<bf8_ocp_t>(x);
 #endif
 }

 // convert bf8 to fp16
 template <>
-inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
+inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
 {
 #if defined(__gfx94__)
     // use native conversion to float and convert to fp16
     return type_convert<half_t>(type_convert<float>(x));
 #else
     constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
+    return utils::cast_from_f8<bf8_fnuz_t, half_t, negative_zero_nan>(x);
 #endif
 }
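A hedged round-trip sketch of the new explicit entry points: callers now pick the FNUZ or OCP specialization by name (or rely on the f8_t alias chosen in data_type.hpp). Only the specializations visible in this diff are used below; the float round-trip is shown for the FNUZ type because that is the back-conversion defined here.

    float x = 0.5f;
    ck::f8_ocp_t  q_ocp  = ck::type_convert<ck::f8_ocp_t>(x);   // RNE or SR per CK_USE_SR_F8_CONVERSION
    ck::f8_fnuz_t q_fnuz = ck::type_convert<ck::f8_fnuz_t>(x);
    float back           = ck::type_convert<float>(q_fnuz);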
include/ck_tile/core/config.hpp (View file @ 175a17f8)

@@ -11,13 +11,15 @@
 #define __gfx94__
 #endif
 #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
-    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
+    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \
+    defined(__gfx10_3_generic__)
 #define __gfx103__
 #endif
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
+    defined(__gfx1103__) || defined(__gfx11_generic__)
 #define __gfx11__
 #endif
-#if defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
 #define __gfx12__
 #endif
include/ck_tile/core/tensor/shuffle_tile.hpp (View file @ 175a17f8)

@@ -170,7 +170,7 @@ CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
     }
     else
     {
-        // NOT implemented
+        static_assert(false, "The shuffle should always happen!");
     }
 }
include/ck_tile/host.hpp (View file @ 175a17f8)

@@ -23,6 +23,7 @@
 #include "ck_tile/host/reference/reference_gemm.hpp"
 #include "ck_tile/host/reference/reference_im2col.hpp"
 #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
+#include "ck_tile/host/reference/reference_moe_sorting.hpp"
 #include "ck_tile/host/reference/reference_permute.hpp"
 #include "ck_tile/host/reference/reference_reduce.hpp"
 #include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
include/ck_tile/host/reference/reference_moe_sorting.hpp (new file, 0 → 100644, View file @ 175a17f8)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
                                        const HostTensor<WeightType>& weights,
                                        HostTensor<IndexType>& p_sorted_token_ids,
                                        HostTensor<WeightType>& sorted_weight,
                                        HostTensor<IndexType>& sorted_expert_ids,
                                        index_t& unit_cnt,
                                        const index_t experts,
                                        const index_t unit_size)
{
    const index_t num_token = topk_ids.mDesc.get_lengths()[0];
    const index_t topk      = topk_ids.mDesc.get_lengths()[1];

    std::vector<std::vector<IndexType>> expert_tokens(experts,
                                                      std::vector<IndexType>(unit_size, num_token));
    std::vector<std::vector<WeightType>> expert_token_weights(
        experts, std::vector<WeightType>(unit_size, 0));
    std::vector<IndexType> expert_slices(experts, 1);
    std::vector<IndexType> expert_slice_idxs(experts, 0);

    for(index_t t = 0; t < num_token; t++)
    {
        for(index_t k = 0; k < topk; k++)
        {
            IndexType e  = topk_ids(t, k);
            WeightType w = weights(t, k);
            index_t idx  = expert_slice_idxs[e];
            if(idx > expert_slices[e] * unit_size - 1)
            {
                expert_slices[e]++;
                index_t new_size = expert_slices[e] * unit_size;
                expert_tokens[e].resize(new_size);
                expert_token_weights[e].resize(new_size);
                for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
                {
                    expert_tokens[e][i]        = num_token;
                    expert_token_weights[e][i] = 0;
                }
            }
            expert_tokens[e][idx]        = t;
            expert_token_weights[e][idx] = w;
            expert_slice_idxs[e]++;
        }
    }

    IndexType* out_tokens    = p_sorted_token_ids.data();
    WeightType* out_weights  = sorted_weight.data();
    IndexType* out_expert_id = sorted_expert_ids.data();
    for(index_t e = 0; e < experts; e++)
    {
        memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
        out_tokens += expert_slices[e] * unit_size;
        memcpy(out_weights,
               expert_token_weights[e].data(),
               sizeof(WeightType) * expert_slices[e] * unit_size);
        out_weights += expert_slices[e] * unit_size;
        for(index_t s = 0; s < expert_slices[e]; s++)
        {
            out_expert_id[s] = e;
            unit_cnt++;
        }
        out_expert_id += expert_slices[e];
    }
    unit_cnt *= unit_size;
    return;
}
} // namespace ck_tile
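A hedged host-side usage sketch for this reference. The tensor sizes and the HostTensor length-list constructor form are assumptions made for illustration; only reference_moe_sorting itself comes from this file. Outputs are sized generously because every expert's token list is padded up to a multiple of unit_size.

    #include "ck_tile/host.hpp"

    void run_reference()
    {
        const ck_tile::index_t num_experts = 4;
        const ck_tile::index_t unit_size   = 32;

        ck_tile::HostTensor<ck_tile::index_t> topk_ids({8, 2}); // (tokens = 8, topk = 2)
        ck_tile::HostTensor<float>            weights({8, 2});
        ck_tile::HostTensor<ck_tile::index_t> sorted_ids({256});
        ck_tile::HostTensor<float>            sorted_w({256});
        ck_tile::HostTensor<ck_tile::index_t> sorted_experts({32});
        ck_tile::index_t unit_cnt = 0;

        // ... fill topk_ids and weights ...
        ck_tile::reference_moe_sorting(topk_ids, weights, sorted_ids, sorted_w,
                                       sorted_experts, unit_cnt, num_experts, unit_size);
        // unit_cnt now holds the padded number of sorted token slots
    }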
include/ck_tile/ops/common/generic_2d_block_shape.hpp (View file @ 175a17f8)

@@ -38,9 +38,7 @@ namespace ck_tile {
 template <typename BlockTile_,    // block size, seq<M, N>
           typename WarpPerBlock_, // num warps along seq<M, N>
           typename WarpTile_,     // warp size, seq<M, N>
-          typename Vector_,       // contiguous pixels(vector size) along seq<M, N>
-          index_t BlockSize_ =
-              warpSize * reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
+          typename Vector_>       // contiguous pixels(vector size) along seq<M, N>
 struct Generic2dBlockShape
 {
     // block size

@@ -68,10 +66,12 @@ struct Generic2dBlockShape
     static_assert(Warp_M % Vector_M == 0);
     static_assert(Warp_N % Vector_N == 0);

     // num of threads along seq<M, N>, within each warp
     static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
     static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;

+    static constexpr index_t ThreadPerBlock_M = Block_M / Repeat_M / Vector_M;
+    static constexpr index_t ThreadPerBlock_N = Block_N / Repeat_N / Vector_N;
+
-    static constexpr index_t BlockSize = BlockSize_;
+    static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N;
 };

 } // namespace ck_tile
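For intuition, a worked example with made-up tile numbers (they are not taken from this diff) using only the formulas visible above:

    // Suppose Block_M = 128, Repeat_M = 2, Vector_M = 1 and Block_N = 64, Repeat_N = 1, Vector_N = 8.
    // ThreadPerBlock_M = 128 / 2 / 1 = 64
    // ThreadPerBlock_N =  64 / 1 / 8 =  8
    // BlockSize        =  64 * 8     = 512   // previously this had to be passed in as BlockSize_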
include/ck_tile/ops/fmha/block/page_block_navigator.hpp (View file @ 175a17f8)

@@ -230,7 +230,15 @@ struct PageBlockNavigator
     CK_TILE_HOST_DEVICE
     DataType* get_block_ptr(index_t block_index) const
     {
-        return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset;
+        if(block_index < num_blocks)
+        {
+            return physical_blocks + physical_block_indices[block_index] * block_stride +
+                   fixed_offset;
+        }
+        else
+        {
+            return nullptr;
+        }
     }

     CK_TILE_HOST_DEVICE
     int32_t get_block_index(const WindowOrigin& global_window_origin) const
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp (View file @ 175a17f8)

@@ -863,6 +863,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         constexpr index_t K0 = kKPerBlock / K1;
         constexpr index_t N2 = get_warp_size() / K0;
         constexpr index_t N1 = kBlockSize / get_warp_size();
+        static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
+        static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
         constexpr index_t N0 = kNPerBlock / (N2 * N1);
         static_assert(N0 != 0);
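The two new static_asserts guard the division that produces N0. With illustrative values (assumptions, not from this diff) get_warp_size() = 64, kKPerBlock = 32, K1 = 16 and kBlockSize = 256:

    // K0 = 32 / 16  = 2
    // N2 = 64 / 2   = 32   // must be non-zero before kNPerBlock is divided by it
    // N1 = 256 / 64 = 4    // likewise
    // N0 = kNPerBlock / (32 * 4)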
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp (new file, 0 → 100644, View file @ 175a17f8)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

struct MoeSortingHostArgs
{
    const void* p_topk_ids;
    const void* p_weights;
    void* p_sorted_token_ids;
    void* p_sorted_weights;
    void* p_sorted_expert_ids;
    void* p_total_tokens_post_pad;
    void* p_moe_buf;

    index_t tokens;
    index_t unit_size;
    index_t num_experts;
    index_t topk;
    index_t moe_buf_bytes;
};

template <typename Problem_>
struct MoeSortingKernel
{
    using Problem    = remove_cvref_t<Problem_>;
    using IndexType  = typename Problem::IndexType;
    using WeightType = typename Problem::WeightType;

    typedef MoeSortingHostArgs MoeSortingKargs;
    using Hargs = MoeSortingHostArgs;

    struct Kargs
    {
        const void* p_topk_ids;
        const void* p_weights;
        void* p_sorted_token_ids;
        void* p_sorted_weights;
        void* p_sorted_expert_ids;
        void* p_total_tokens_post_pad;
        void* p_moe_buf;

        index_t tokens;
        index_t num_experts;
        index_t moe_buf_bytes;
        index_t tokens_per_thread;
        mdiv unit_size_mdiv;
        mdiv topk_mdiv;
    };

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
    {
        // TODO: assume num-experts not too much
        return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BlockSize(h).x * 16));
    }

    CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
    {
        return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
    }

    // in byte
    CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
    {
        const auto blocks = BlockSize(h);
        return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t);
    }

    CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
    {
        Kargs k;
        k.p_topk_ids              = h.p_topk_ids;
        k.p_weights               = h.p_weights;
        k.p_sorted_token_ids      = h.p_sorted_token_ids;
        k.p_sorted_weights        = h.p_sorted_weights;
        k.p_sorted_expert_ids     = h.p_sorted_expert_ids;
        k.p_moe_buf               = h.p_moe_buf;
        k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
        k.tokens                  = h.tokens;
        k.num_experts             = h.num_experts;
        k.moe_buf_bytes           = h.moe_buf_bytes;
        const auto blocks         = BlockSize(h);
        k.tokens_per_thread       = integer_divide_ceil(h.tokens * h.topk, blocks.x);
        k.unit_size_mdiv          = mdiv{static_cast<uint32_t>(h.unit_size)};
        k.topk_mdiv               = mdiv{static_cast<uint32_t>(h.topk)};
        return k;
    }

    CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
    {
        return row * total_col + col;
    }

    CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes) const
    {
        const index_t offset = (blockIdx.x - 1) * blockDim.x + threadIdx.x;
        if(offset < buf_bytes / 16)
        {
            buf[offset] = uint8x16_t{0};
        }
    }

    CK_TILE_DEVICE void
    moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
                                const WeightType* __restrict__ weights,
                                index_t* p_sorted_token_ids,
                                WeightType* p_sorted_weights,
                                index_t* p_sorted_expert_ids,
                                index_t* p_total_tokens_post_pad,
                                const index_t num_experts,
                                const index_t tokens_per_thread,
                                const index_t numel,
                                const mdiv unit_size_mdiv,
                                const mdiv topk_mdiv,
                                void* smem) const
    {
        const index_t tid       = static_cast<index_t>(threadIdx.x);
        const index_t start_idx = tid * tokens_per_thread;

        index_t* shared_mem  = reinterpret_cast<index_t*>(smem);
        index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
        index_t* cumsum      = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1)

        for(int i = 0; i < num_experts; ++i)
        {
            tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0;
        }

#pragma unroll Problem_::InternalLoadUnroll
        for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
        {
            ++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])];
        }

        __syncthreads();

        if(tid < num_experts)
        {
            tokens_cnts[calc_index(num_experts, 0, tid)] = 0;
            for(int i = 1; i <= static_cast<index_t>(blockDim.x); ++i)
            {
                tokens_cnts[calc_index(num_experts, i, tid)] +=
                    tokens_cnts[calc_index(num_experts, i - 1, tid)];
            }
        }

        // __syncthreads();
        if(tid == 0)
        {
            cumsum[0] = 0;
            for(int i = 1; i <= num_experts; ++i)
            {
                auto current_units = [&]() {
                    index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] +
                                 unit_size_mdiv.divisor - 1;
                    index_t y_ = unit_size_mdiv.div(x_);
                    return max(y_, 1) * unit_size_mdiv.divisor;
                }();
                cumsum[i] = cumsum[i - 1] + current_units;
            }
            *p_total_tokens_post_pad = cumsum[num_experts];
        }

        __syncthreads();

        if(tid < num_experts)
        {
            for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor)
            {
                p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
            }
        }

#pragma unroll Problem_::InternalLoadUnroll
        for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
        {
            index_t expert_id = topk_id[i];
            index_t rank_post_pad =
                tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
            p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
            p_sorted_weights[rank_post_pad]   = weights[i];
            ++tokens_cnts[calc_index(num_experts, tid, expert_id)];
        }

        const index_t prefill_token = topk_mdiv.div(numel);
        if(tid < num_experts)
        {
            index_t expert_offset =
                cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
            while(expert_offset < cumsum[tid + 1])
            {
                p_sorted_token_ids[expert_offset] = prefill_token;
                p_sorted_weights[expert_offset]   = static_cast<WeightType>(0.0);
                expert_offset++;
            }
        }
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        if(blockIdx.x > 0)
        {
            if(kargs.p_moe_buf)
            {
                moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
                                        kargs.moe_buf_bytes);
            }
            return;
        }
        const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
        extern __shared__ char smem[];
        return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
                                           static_cast<const WeightType*>(kargs.p_weights),
                                           static_cast<IndexType*>(kargs.p_sorted_token_ids),
                                           static_cast<WeightType*>(kargs.p_sorted_weights),
                                           static_cast<IndexType*>(kargs.p_sorted_expert_ids),
                                           static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
                                           kargs.num_experts,
                                           kargs.tokens_per_thread,
                                           numel,
                                           kargs.unit_size_mdiv,
                                           kargs.topk_mdiv,
                                           smem);
    }
};
} // namespace ck_tile
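A hedged host-side sketch of driving this kernel. Only GridSize/BlockSize/GetSmemSize/MakeKargs come from the code above; the launch wrapper name, the unroll factor, and the empty host-args initializer are assumptions for illustration.

    using Problem = ck_tile::MoeSortingProblem<ck_tile::index_t, float, /*InternalLoadUnroll*/ 4>;
    using Kernel  = ck_tile::MoeSortingKernel<Problem>;

    ck_tile::MoeSortingHostArgs hargs{}; // fill pointers, tokens, unit_size, num_experts, topk, moe_buf_bytes
    auto kargs  = Kernel::MakeKargs(hargs);
    dim3 grids  = Kernel::GridSize(hargs);
    dim3 blocks = Kernel::BlockSize(hargs);
    auto smem   = Kernel::GetSmemSize(hargs);
    // launch_kernel(Kernel{}, grids, blocks, smem, kargs); // launcher name is hypothetical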
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp (new file, 0 → 100644, View file @ 175a17f8)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include <string>
#include <type_traits>

#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
#endif

namespace ck_tile {

// template <typename Problem_, typename Policy_ = MoeSortingPolicy>
// struct MoeSortingPipeline
// {
//     // TODO: this kernel only support warp per row
//     using Problem    = remove_cvref_t<Problem_>;
//     using Policy     = remove_cvref_t<Policy_>;
//     using WeightType = typename Problem::WeightType;
//
//     template <typename TopkIdWindow, typename WeightWindow>
//     CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
//                                    const WeightWindow& weight_window,
//                                    index_t* p_sorted_token_ids,
//                                    WeightType* p_sorted_weights,
//                                    index_t* p_sorted_expert_ids,
//                                    index_t* p_total_tokens_post_pad,
//                                    const index_t num_experts,
//                                    const index_t unit_size,
//                                    const size_t numel,
//                                    const index_t topk)
//     {
//     }
// };
} // namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp (new file, 0 → 100644, View file @ 175a17f8)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"

namespace ck_tile {

struct MoeSortingPolicy
{
};
} // namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp (new file, 0 → 100644, View file @ 175a17f8)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

template <typename IndexType_, typename WeightType_, index_t InternalLoadUnroll_>
struct MoeSortingProblem
{
    // TODO: this kernel only support warp per row
    using WeightType = remove_cvref_t<WeightType_>;
    using IndexType  = remove_cvref_t<IndexType_>;

    static constexpr index_t WarpSize           = get_warp_size();
    static constexpr index_t WarpsPerBlock      = 1;
    static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_;
};
} // namespace ck_tile
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
175a17f8
...
...
@@ -115,12 +115,22 @@ struct GemmKernel
}
}();
auto
a_pad_view
=
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
// somehow clang-format is splitting below line into multiple.
// clang-format off
sequence
<
false
,
GemmPipeline
::
kPadA
>
{});
auto
a_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
// clang-format on
auto
a_block_window
=
make_tile_window
(
...
...
@@ -128,12 +138,22 @@ struct GemmKernel
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
        auto b_pad_view = pad_tensor_view(
            b_tensor_view,
            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
            // clang-format off
            sequence<false, GemmPipeline::kPadB>{});
        // clang-format on
        auto b_pad_view = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            {
                return pad_tensor_view(
                    b_tensor_view,
                    make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
                    sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(
                    b_tensor_view,
                    make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
                    sequence<GemmPipeline::kPadN, false>{});
            }
        }();
        auto b_block_window = make_tile_window(
            b_pad_view,
...
@@ -171,18 +191,28 @@ struct GemmKernel
}
}();
        auto c_pad_view = pad_tensor_view(
            c_tensor_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
            // clang-format off
            sequence<false, GemmPipeline::kPadC>{});
        // clang-format on
        auto c_block_window = make_tile_window(
        auto c_pad_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(
                    c_tensor_view,
                    make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
                    sequence<false, GemmPipeline::kPadN>{});
            }
            else
            {
                return pad_tensor_view(
                    c_tensor_view,
                    make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
                    sequence<GemmPipeline::kPadM, false>{});
            }
        }();
        auto CBlockWindow_pad = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
            {i_m, i_n});
        EpiloguePipeline{}(c_block_window, c_block_tile);
        EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
}
};
...
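The change in this file makes each padding sequence depend on the tensor layout: the pad flag is applied only to the dimension that is contiguous for that layout (K for row-major A, M for column-major A, and analogously for B and C). Below is a self-contained sketch of that selection logic using toy stand-in layout tags rather than the real ck_tile types; it only mirrors the sequence choices visible in the diff.

// Standalone illustration of the layout-dependent padding choice above.
// RowMajor / ColumnMajor here are toy stand-ins, not the ck_tile tags.
#include <type_traits>
#include <utility>

struct RowMajor {};
struct ColumnMajor {};

template <typename ALayout, bool kPadM, bool kPadK>
constexpr std::pair<bool, bool> pick_a_padding() // {pad M?, pad K?}
{
    if constexpr(std::is_same_v<ALayout, RowMajor>)
        return {false, kPadK}; // mirrors sequence<false, GemmPipeline::kPadK>
    else
        return {kPadM, false}; // mirrors sequence<GemmPipeline::kPadM, false>
}

static_assert(pick_a_padding<RowMajor, true, true>().second);     // row-major pads K
static_assert(!pick_a_padding<ColumnMajor, true, true>().second); // column-major pads M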
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
...
@@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
    static constexpr index_t VectorSizeC = Problem::VectorSizeC;
    static constexpr bool kPadA = Problem::kPadA;
    static constexpr bool kPadB = Problem::kPadB;
    static constexpr bool kPadC = Problem::kPadC;
    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;
    // Where is the right place for HasHotLoop and TailNum ???
    static constexpr bool HasHotLoop = Problem::HasHotLoop;
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...
@@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1
    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
    static constexpr index_t VectorSizeC = Problem::VectorSizeC;
    static constexpr bool kPadA = Problem::kPadA;
    static constexpr bool kPadB = Problem::kPadB;
    static constexpr bool kPadC = Problem::kPadC;
    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;

    CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
    {
...
@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1
                             Policy::template MakeADramTileDistribution<Problem>());
        // A LDS tile window for store
        auto a_copy_lds_window =
            make_tile_window(a_lds_block,
                             make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                             {0, 0},
                             a_copy_dram_window.get_tile_distribution());
        auto a_copy_lds_window = make_tile_window(
            a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
        // B DRAM tile window for load
        auto b_copy_dram_window =
...
@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1
                             Policy::template MakeBDramTileDistribution<Problem>());
        // B LDS tile window for store
        auto b_copy_lds_window =
            make_tile_window(b_lds_block,
                             make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
                             {0, 0},
                             b_copy_dram_window.get_tile_distribution());
        auto b_copy_lds_window = make_tile_window(
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
        // A LDS tile for block GEMM
        auto a_lds_gemm_window = make_tile_window(
...
@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1
        tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
        // LDS write 0
        const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
        store_tile(a_copy_lds_window, a_block_tile_tmp);
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
        {
            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                Policy::template MakeShuffledARegBlockDescriptor<Problem>());
            shuffle_tile(a_shuffle_tmp, a_block_tile);
            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
            store_tile(a_copy_lds_window, a_block_tile_tmp);
        }
        else
        {
            store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
        }
        // LDS write 0
        const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
        store_tile(b_copy_lds_window, b_block_tile_tmp);
        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
            shuffle_tile(b_shuffle_tmp, b_block_tile);
            const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
            store_tile(b_copy_lds_window, b_block_tile_tmp);
        }
        else
        {
            store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
        }
        }
        index_t iCounter = num_loop - 1;
...
@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1
            store_tile(a_copy_lds_window, a_block_tile_tmp);
            // LDS write i + 1
            const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
            store_tile(b_copy_lds_window, b_block_tile_tmp);
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
                    Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
                shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
                store_tile(b_copy_lds_window,
                           tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
            }
            else
            {
                const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
                store_tile(b_copy_lds_window, b_block_tile_tmp);
            }
            iCounter--;
        }
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...
@@ -11,6 +11,7 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
#if 0
// 2d
template <typename Problem>
...
@@ -116,6 +117,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
        return smem_size;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
    {
        using ADataType = remove_cvref_t<typename Problem::ADataType>;
        return Problem::VectorLoadSize / sizeof(ADataType);
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
    {
        using BDataType = remove_cvref_t<typename Problem::BDataType>;
        return Problem::VectorLoadSize / sizeof(BDataType);
    }
#elif 1
// fake XOR
    template <typename Problem>
...
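The GetSmemPackA/GetSmemPackB helpers added above reduce to "vector load size divided by element size". A standalone check of that arithmetic follows; the 16-byte vector load is an assumed example value, not necessarily what a given Problem configuration uses.

// Standalone check of the pack-size arithmetic used by GetSmemPackA/B:
// pack = VectorLoadSize / sizeof(DataType).
#include <cstddef>
#include <cstdint>

constexpr std::size_t kVectorLoadSize = 16; // bytes, illustrative

template <typename DataType>
constexpr std::size_t smem_pack() { return kVectorLoadSize / sizeof(DataType); }

static_assert(smem_pack<std::uint16_t>() == 8); // e.g. 2-byte fp16/bf16 elements
static_assert(smem_pack<float>() == 4);
static_assert(smem_pack<std::uint8_t>() == 16); // e.g. 1-byte fp8/int8 elements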
@@ -192,80 +207,269 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
    CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
    {
        using ADataType = remove_cvref_t<typename Problem::ADataType>;
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t K1 = 16 / sizeof(ADataType);
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t M2 = get_warp_size() / K0;
#if 1 // coalesce reading for each blocks
        constexpr index_t M1 = kBlockSize / get_warp_size();
        static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
        static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
        constexpr index_t M0 = kMPerBlock / (M2 * M1);
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
#else // coalesce reading for each warps
        constexpr index_t M0 = kBlockSize / get_warp_size();
        constexpr index_t M1 = kMPerBlock / (M2 * M0);
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<0>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 1>>{});
#endif
        using ALayout = remove_cvref_t<typename Problem::ALayout>;

        constexpr index_t BlockSize = Problem::kBlockSize;
        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
        {
            constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
            constexpr index_t M0 = MPerBlock / M1;
            constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
            static_assert(total_pixels % M1 == 0);
            constexpr index_t K3    = total_pixels / M1;
            constexpr index_t KPack = GetSmemPackA<Problem>();
            static_assert(KPack % K3 == 0);
            constexpr index_t K2 = KPack / K3;
            if constexpr(get_warp_size() % (K2 * M0))
            {
                constexpr index_t K1 = get_warp_size() / (K2 * M0);
                constexpr index_t K0 = BlockSize / get_warp_size();
                static_assert(KPerBlock == K0 * K1 * K2 * K3);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
                                               tuple<sequence<2>, sequence<2, 1, 2>>,
                                               tuple<sequence<0>, sequence<1, 0, 2>>,
                                               sequence<2, 1>,
                                               sequence<3, 1>>{});
            }
            else
            {
                constexpr index_t K1   = (K2 * M0) / get_warp_size();
                constexpr index_t K2_m = K2 / K1;
                constexpr index_t K0   = BlockSize / get_warp_size() / K1;
                static_assert(KPerBlock == K0 * K1 * K2_m * K3);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
                                               tuple<sequence<2, 2>, sequence<1, 2>>,
                                               tuple<sequence<0, 1>, sequence<0, 2>>,
                                               sequence<2, 1>,
                                               sequence<3, 1>>{});
            }
        }
        else
        {
            constexpr index_t K1 = 16 / sizeof(ADataType);
            constexpr index_t K0 = KPerBlock / K1;
            constexpr index_t M2 = get_warp_size() / K0;
            // coalesce reading for each blocks
            if constexpr(get_warp_size() % (M2 * K0) == 0)
            {
                constexpr index_t M1 = BlockSize / get_warp_size();
                static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
                static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
                constexpr index_t M0 = MPerBlock / (M2 * M1);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
                                               tuple<sequence<1>, sequence<1, 2>>,
                                               tuple<sequence<1>, sequence<2, 0>>,
                                               sequence<1, 2>,
                                               sequence<0, 1>>{});
            }
            else
            {
                constexpr index_t M0 = BlockSize / get_warp_size();
                constexpr index_t M1 = MPerBlock / (M2 * M0);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
                                               tuple<sequence<1>, sequence<1, 2>>,
                                               tuple<sequence<0>, sequence<2, 0>>,
                                               sequence<1, 2>,
                                               sequence<1, 1>>{});
            }
        }
    }
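To make the factorization in the row-major (else) branch above concrete, here is a host-side worked example of the same formulas. The tile shape, element size, block size, and wave size below are assumed example values, not values mandated by the commit.

// Worked example of the row-major-A factorization: 128x32 tile, 2-byte
// elements, 256 threads per block, wave64 (all assumed).
constexpr int warp_size   = 64;
constexpr int block_size  = 256;
constexpr int m_per_block = 128;
constexpr int k_per_block = 32;
constexpr int elem_bytes  = 2;

constexpr int K1 = 16 / elem_bytes;         // 8 elements per 16-byte load
constexpr int K0 = k_per_block / K1;        // 4
constexpr int M2 = warp_size / K0;          // 16
constexpr int M1 = block_size / warp_size;  // 4
constexpr int M0 = m_per_block / (M2 * M1); // 2

static_assert(M0 * M1 * M2 == m_per_block);
static_assert(K0 * K1 == k_per_block);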
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
    {
        using BDataType = remove_cvref_t<typename Problem::BDataType>;
        using BLayout   = remove_cvref_t<typename Problem::BLayout>;

        constexpr index_t BlockSize = Problem::kBlockSize;
        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
        {
            constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
            constexpr index_t N0 = NPerBlock / N1;
            constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
            static_assert(total_pixels % N1 == 0);
            constexpr index_t K3    = total_pixels / N1;
            constexpr index_t KPack = GetSmemPackB<Problem>();
            static_assert(KPack % K3 == 0);
            constexpr index_t K2 = KPack / K3;
            if constexpr(get_warp_size() % (K2 * N0) == 0)
            {
                constexpr index_t K1 = get_warp_size() / (K2 * N0);
                constexpr index_t K0 = BlockSize / get_warp_size();
                static_assert(KPerBlock == K0 * K1 * K2 * K3);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
                                               tuple<sequence<2>, sequence<2, 1, 2>>,
                                               tuple<sequence<0>, sequence<1, 0, 2>>,
                                               sequence<2, 1>,
                                               sequence<3, 1>>{});
            }
            else
            {
                constexpr index_t K1   = (K2 * N0) / get_warp_size();
                constexpr index_t K2_m = K2 / K1;
                constexpr index_t K0   = BlockSize / get_warp_size() / K1;
                static_assert(KPerBlock == K0 * K1 * K2_m * K3);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
                                               tuple<sequence<2, 2>, sequence<1, 2>>,
                                               tuple<sequence<0, 1>, sequence<0, 2>>,
                                               sequence<2, 1>,
                                               sequence<3, 1>>{});
            }
        }
        else
        {
            constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
            constexpr index_t K0 = KPerBlock / K1;
            constexpr index_t N2 = get_warp_size() / K0;
            // coalesce reading for each blocks
            if constexpr(get_warp_size() % (N2 * K0) == 0)
            {
                constexpr index_t N1 = BlockSize / get_warp_size();
                static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
                static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
                constexpr index_t N0 = NPerBlock / (N2 * N1);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                               tuple<sequence<1>, sequence<1, 2>>,
                                               tuple<sequence<1>, sequence<2, 0>>,
                                               sequence<1, 2>,
                                               sequence<0, 1>>{});
            }
            // coalesce reading for each warps
            else
            {
                constexpr index_t N0 = BlockSize / get_warp_size();
                constexpr index_t N1 = NPerBlock / (N2 * N0);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                               tuple<sequence<1>, sequence<1, 2>>,
                                               tuple<sequence<0>, sequence<2, 0>>,
                                               sequence<1, 2>,
                                               sequence<1, 1>>{});
            }
        }
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
    {
        using BLayout   = remove_cvref_t<typename Problem::BLayout>;
        using BDataType = remove_cvref_t<typename Problem::BDataType>;
        static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);

        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t K1 = 16 / sizeof(BDataType);
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t N2 = get_warp_size() / K0;
#if 1 // coalesce reading for each blocks
        constexpr index_t N1 = kBlockSize / get_warp_size();
        static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
        static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
        constexpr index_t N0 = kNPerBlock / (N2 * N1);
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
#else // coalesce reading for each warps
        constexpr index_t N0 = kBlockSize / get_warp_size();
        constexpr index_t N1 = kNPerBlock / (N2 * N0);
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<0>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 1>>{});
#endif
        constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
        constexpr index_t N0 = kNPerBlock / N1;
        constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
        static_assert(total_pixels % N1 == 0);
        constexpr index_t K3     = total_pixels / N1;
        constexpr index_t kKPack = GetSmemPackB<Problem>();
        static_assert(kKPack % K3 == 0);
        constexpr index_t K2 = kKPack / K3;
        // TODO: this dimension could be outside single wave
        constexpr index_t warp_size = get_warp_size();
        if constexpr(warp_size % (K2 * N0) == 0)
        {
            constexpr index_t K1 = warp_size / (K2 * N0);
            constexpr index_t K0 = kBlockSize / warp_size;
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
                                           tuple<sequence<2>, sequence<2, 1, 2>>,
                                           tuple<sequence<0>, sequence<1, 0, 2>>,
                                           sequence<1, 2>,
                                           sequence<1, 3>>{});
        }
        else
        {
            constexpr index_t K1   = (K2 * N0) / get_warp_size();
            constexpr index_t K2_m = K2 / K1;
            constexpr index_t K0   = kBlockSize / get_warp_size() / K1;
            static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
                                           tuple<sequence<2, 2>, sequence<1, 2>>,
                                           tuple<sequence<0, 1>, sequence<0, 2>>,
                                           sequence<1, 2>,
                                           sequence<1, 3>>{});
        }
    }
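For the vector-load-oriented factorization used by the new MakeShuffledBRegBlockDescriptor body (and by the layout-aware DRAM distributions above), here is a worked numeric example of the N1/N0/total_pixels/K3/K2 chain. The tile shape, element size, block size, wave size, and the 16-byte vector load are all assumed example values, not taken from a specific Problem configuration.

// Worked example: 128x32 B tile, 2-byte elements, 256 threads, wave64,
// 16-byte vector loads (all assumed).
constexpr int warp_size        = 64;
constexpr int block_size       = 256;
constexpr int n_per_block      = 128;
constexpr int k_per_block      = 32;
constexpr int elem_bytes       = 2;
constexpr int vector_load_size = 16;

constexpr int N1           = vector_load_size / elem_bytes;          // 8
constexpr int N0           = n_per_block / N1;                       // 16
constexpr int total_pixels = n_per_block * k_per_block / block_size; // 16 elements per thread
constexpr int K3           = total_pixels / N1;                      // 2
constexpr int KPack        = vector_load_size / elem_bytes;          // GetSmemPackB analogue, 8
constexpr int K2           = KPack / K3;                             // 4
constexpr int K1           = warp_size / (K2 * N0);                  // 1 (the 64 % 64 == 0 path)
constexpr int K0           = block_size / warp_size;                 // 4

static_assert(total_pixels % N1 == 0 && KPack % K3 == 0);
static_assert(K0 * K1 * K2 * K3 == k_per_block);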
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
    {
        using ALayout   = remove_cvref_t<typename Problem::ALayout>;
        using ADataType = remove_cvref_t<typename Problem::ADataType>;
        static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>);

        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

        constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
        constexpr index_t M0 = kMPerBlock / M1;
        constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
        static_assert(total_pixels % M1 == 0);
        constexpr index_t K3     = total_pixels / M1;
        constexpr index_t kKPack = GetSmemPackA<Problem>();
        static_assert(kKPack % K3 == 0);
        constexpr index_t K2 = kKPack / K3;
        // TODO: this dimension could be outside single wave
        constexpr index_t warp_size = get_warp_size();
        if constexpr(warp_size % (K2 * M0) == 0)
        {
            constexpr index_t K1 = warp_size / (K2 * M0);
            constexpr index_t K0 = kBlockSize / warp_size;
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
                                           tuple<sequence<2>, sequence<2, 1, 2>>,
                                           tuple<sequence<0>, sequence<1, 0, 2>>,
                                           sequence<1, 2>,
                                           sequence<1, 3>>{});
        }
        else
        {
            constexpr index_t K1   = (K2 * M0) / get_warp_size();
            constexpr index_t K2_m = K2 / K1;
            constexpr index_t K0   = kBlockSize / get_warp_size() / K1;
            static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
                                           tuple<sequence<2, 2>, sequence<1, 2>>,
                                           tuple<sequence<0, 1>, sequence<0, 2>>,
                                           sequence<1, 2>,
                                           sequence<1, 3>>{});
        }
    }
    template <typename Problem>
...