gaoqiong / composable_kernel_ROCM / Commits

Commit d480a5a6 (unverified), authored Feb 03, 2025 by Max Podkorytov; committed by GitHub on Feb 03, 2025

    Merge branch 'develop' into ck-flex

Parents: bca939ce, 9c5b2f39

Changes: 94 files changed in the merge; this page shows 20 changed files with 965 additions and 386 deletions (+965 -386).
Files shown on this page:

include/ck/utility/env.hpp                                    +3    -1
include/ck/utility/functional.hpp                             +3    -3
include/ck/utility/functional4.hpp                            +6    -6
include/ck/utility/integral_constant.hpp                      +6    -1
include/ck/utility/is_detected.hpp                            +9    -7
include/ck/utility/loop_scheduler.hpp                         +6    -1
include/ck/utility/magic_division.hpp                         +5    -1
include/ck/utility/math_v2.hpp                                +3    -3
include/ck/utility/random_gen.hpp                             +15   -11
include/ck/utility/sequence.hpp                               +5    -1
include/ck/utility/statically_indexed_array_multi_index.hpp   +18   -23
include/ck/utility/tuple.hpp                                  +8    -8
include/ck/utility/tuple_helper.hpp                           +10   -4
include/ck/utility/type.hpp                                   +316  -49
include/ck/utility/type_convert.hpp                           +32   -12
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp            +146  -151
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp          +97   -4
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp       +6    -6
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp               +34   -48
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp     +237  -46
include/ck/utility/env.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+#ifndef CK_CODE_GEN_RTC
 #pragma once
 #include <cstdlib>
...
@@ -183,3 +184,4 @@ void UpdateEnvVar(EnvVar, const std::string_view& val)
 }
 } // namespace ck
+#endif
include/ck/utility/functional.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y)
 {
     if constexpr(predicate)
     {
-        return std::forward<X>(x);
+        return ck::forward<X>(x);
     }
     else
     {
-        return std::forward<Y>(y);
+        return ck::forward<Y>(y);
     }
 }
...
include/ck/utility/functional4.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #ifndef CK_FUNCTIONAL4_HPP
 #define CK_FUNCTIONAL4_HPP
...
@@ -21,7 +21,7 @@ struct unpack_impl<Sequence<Is...>>
     template <typename F, typename X>
     __host__ __device__ constexpr auto operator()(F&& f, X&& x) const
     {
-        return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...);
+        return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...);
     }
 };
...
@@ -35,8 +35,8 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
     template <typename F, typename X, typename Y>
     __host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const
     {
-        return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...,
-                                  std::forward<Y>(y).At(Number<Js>{})...);
+        return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...,
+                                 ck::forward<Y>(y).At(Number<Js>{})...);
     }
 };
...
@@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x)
 {
     using X_ = remove_reference_t<X>;
     return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type>{}(
-        std::forward<F>(f), std::forward<X>(x));
+        ck::forward<F>(f), ck::forward<X>(x));
 }
 // TODO: properly implement unpack that takes any number of containers
...
@@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
     using Y_ = remove_reference_t<Y>;
     return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type,
                                 typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}(
-        std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
+        ck::forward<F>(f), ck::forward<X>(x), ck::forward<Y>(y));
 }
 } // namespace ck
...
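Aside: the unpack/unpack2 changes above only swap std::forward for ck::forward; the expansion mechanism itself is unchanged. For readers unfamiliar with the idiom, here is a minimal std-based sketch of what unpack does (the _sketch names are illustrative, not CK API):

    #include <tuple>
    #include <utility>
    #include <cassert>

    // Expand a tuple's elements into a callable's argument list via an index
    // sequence -- the same trick unpack_impl performs with Sequence<Is...> and
    // .At(Number<Is>{})... above.
    template <typename F, typename Tup, std::size_t... Is>
    constexpr auto unpack_sketch_impl(F&& f, Tup&& t, std::index_sequence<Is...>)
    {
        return std::forward<F>(f)(std::get<Is>(std::forward<Tup>(t))...);
    }

    template <typename F, typename Tup>
    constexpr auto unpack_sketch(F&& f, Tup&& t)
    {
        constexpr std::size_t N = std::tuple_size<std::remove_reference_t<Tup>>::value;
        return unpack_sketch_impl(
            std::forward<F>(f), std::forward<Tup>(t), std::make_index_sequence<N>{});
    }

    int main()
    {
        auto sum3 = [](int a, int b, int c) { return a + b + c; };
        assert(unpack_sketch(sum3, std::make_tuple(1, 2, 3)) == 6);
        return 0;
    }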
include/ck/utility/integral_constant.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -48,4 +48,9 @@ __host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_
     return integral_constant<decltype(X % Y), X % Y>{};
 }

+template <bool B>
+using bool_constant = integral_constant<bool, B>;
+
+using true_type  = bool_constant<true>;
+using false_type = bool_constant<false>;
 } // namespace ck
include/ck/utility/is_detected.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#include "ck/utility/integral_constant.hpp"
+
 namespace ck {
 namespace detail {
 template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
 struct detector
 {
-    using value_t = std::false_type;
+    using value_t = integral_constant<bool, false>;
     using type    = Default;
 };

 template <class Default, template <class...> class Op, class... Args>
-struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
+struct detector<Default, ck::void_t<Op<Args...>>, Op, Args...>
 {
-    using value_t = std::true_type;
+    using value_t = integral_constant<bool, true>;
     using type    = Op<Args...>;
 };
 } // namespace detail
...
@@ -32,12 +34,12 @@ template <template <class...> class Op, class... Args>
 using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;

 template <typename T>
-using is_pack2_invocable_t = decltype(std::declval<T&>().is_pack2_invocable);
+using is_pack2_invocable_t = decltype(ck::declval<T&>().is_pack2_invocable);

 template <typename T>
-using is_pack4_invocable_t = decltype(std::declval<T&>().is_pack4_invocable);
+using is_pack4_invocable_t = decltype(ck::declval<T&>().is_pack4_invocable);

 template <typename T>
-using is_pack8_invocable_t = decltype(std::declval<T&>().is_pack8_invocable);
+using is_pack8_invocable_t = decltype(ck::declval<T&>().is_pack8_invocable);
 } // namespace ck
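The detector/is_detected machinery above is the standard C++ detection idiom: a partial specialization on void_t<Op<Args...>> is chosen only when the expression Op<Args...> is well-formed. A minimal self-contained sketch (std-based, with illustrative _sketch names; not the CK code itself):

    #include <type_traits>
    #include <utility>

    // Archetype: is the member call t.IsTuple() well-formed for a T?
    template <typename T>
    using has_is_tuple_t = decltype(std::declval<T&>().IsTuple());

    // Primary template: detection failed.
    template <typename T, template <class...> class Op, class = void>
    struct is_detected_sketch : std::false_type
    {
    };

    // Partial specialization: chosen when Op<T> names a valid type.
    template <typename T, template <class...> class Op>
    struct is_detected_sketch<T, Op, std::void_t<Op<T>>> : std::true_type
    {
    };

    struct WithTuple { bool IsTuple(); };
    struct Without {};

    static_assert(is_detected_sketch<WithTuple, has_is_tuple_t>::value, "detected");
    static_assert(!is_detected_sketch<Without, has_is_tuple_t>::value, "not detected");

    int main() { return 0; }

The commit's change replaces the std pieces (std::void_t, std::declval, std::true_type/false_type) with ck-provided equivalents so the idiom keeps working under CK_CODE_GEN_RTC, where the standard library is unavailable.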
include/ck/utility/loop_scheduler.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+#ifndef CK_CODE_GEN_RTC
+#include <ostream>
+#endif

 #pragma once
...
@@ -25,6 +28,7 @@ constexpr LoopScheduler make_default_loop_scheduler()
 } // namespace ck

+#ifndef CK_CODE_GEN_RTC
 inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
 {
     switch(s)
...
@@ -35,3 +39,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
     }
     return os;
 }
+#endif
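This header, like env.hpp above and sequence.hpp, math_v2.hpp, and tuple_helper.hpp below, applies one recurring pattern of this merge: host-only standard-library usage is fenced behind CK_CODE_GEN_RTC so the header remains compilable for runtime compilation, where <ostream> and friends do not exist. A minimal sketch of the pattern (the print helper is illustrative only, not CK code):

    #ifndef CK_CODE_GEN_RTC
    #include <ostream>
    // Host-side debugging convenience; compiled out entirely in RTC builds.
    inline std::ostream& print_mode(std::ostream& os) { return os << "host"; }
    #endif

    int main() { return 0; }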
include/ck/utility/magic_division.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -9,6 +9,10 @@
 #include "type.hpp"
 #include "tuple.hpp"

+#ifdef CK_CODE_GEN_RTC
+#define INT32_MAX 2147483647
+#endif
+
 namespace ck {

 // magic number division
...
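Background for this header: magic-number division replaces an integer divide with a precomputed multiply-and-shift, which is far cheaper in GPU index arithmetic. A worked instance with the textbook 32-bit magic constant for division by 3 (CK derives such constants generically; this example is illustration, not CK's API):

    #include <cassert>
    #include <cstdint>

    int main()
    {
        // magic = ceil(2^33 / 3); then n / 3 == (n * magic) >> 33 for all uint32_t n
        const uint64_t magic = 0xAAAAAAABull;
        for(uint32_t n : {0u, 1u, 2u, 3u, 7u, 100u, 2147483647u})
        {
            const uint32_t q = static_cast<uint32_t>((n * magic) >> 33); // no divide issued
            assert(q == n / 3);
        }
        return 0;
    }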
include/ck/utility/math_v2.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -19,7 +19,7 @@ extern "C" __device__ float __ocml_native_recip_f32(float);
 #endif

 // math functions for the host, some are implemented by calling C++ std functions
+#ifndef CK_CODE_GEN_RTC
 static inline __host__ float abs(float x) { return std::abs(x); };
 static inline __host__ double abs(double x) { return std::abs(x); };
...
@@ -459,7 +459,7 @@ inline __host__ double expm1<double>(double x)
 {
     return std::expm1(x);
 }
+#endif

 // math functions for the HIP kernel, some are implemented by calling hip builtin functions
 static inline __device__ float abs(float x) { return ::abs(x); };
...
include/ck/utility/random_gen.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#include <ck/utility/ignore.hpp>
 #include "ck/ck.hpp"

+#ifdef CK_CODE_GEN_RTC
+using uint8_t  = unsigned char;
+using uint16_t = unsigned short;
+using uint32_t = unsigned int;
+#endif
+
 namespace ck {

 // Pseudo random number generator
 // version for fp32
-template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
+template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<float, T>{}, bool> = false>
 __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
 {
     uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
...
@@ -25,7 +30,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
 }

 // version for fp16
-template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
+template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
 __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
 {
     uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
...
@@ -40,15 +45,14 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
 }

 // return 0 if data is not fp16 or fp32
 template <typename T,
           uint32_t seed_t,
-          std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> =
-              false>
+          ck::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
 __host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
 {
-    std::ignore = id;
-    std::ignore = val;
-    std::ignore = seed;
+    ck::ignore = id;
+    ck::ignore = val;
+    ck::ignore = seed;
     return 0;
 }
...
include/ck/utility/sequence.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef CK_CODE_GEN_RTC
 #include <ostream>
+#endif

 #include "ck/utility/integral_constant.hpp"
 #include "ck/utility/type.hpp"
...
@@ -900,6 +902,7 @@ using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
 } // namespace ck

+#ifndef CK_CODE_GEN_RTC
 template <ck::index_t... Is>
 std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
 {
...
@@ -910,3 +913,4 @@ std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
     os << S::At(S::Size() - ck::Number<1>{}).value << "}";
     return os;
 }
+#endif
include/ck/utility/statically_indexed_array_multi_index.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
 #define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
...
@@ -35,10 +35,9 @@ __host__ __device__ constexpr auto to_multi_index(const T& x)
 // is the alias of the latter. This is because compiler cannot infer the NSize if
 // using MultiIndex<NSize>
 // TODO: how to fix this?
 template <typename... Ys,
           typename X,
-          enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value,
-                      bool> = false>
+          enable_if_t<!ck::is_integral<X>::value && !ck::is_floating_point<X>::value, bool> = false>
 __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
 {
     static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
...
@@ -47,10 +46,9 @@ __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
     return y;
 }

 template <typename... Ys,
           typename X,
-          enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value,
-                      bool> = false>
+          enable_if_t<!ck::is_integral<X>::value && !ck::is_floating_point<X>::value, bool> = false>
 __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
 {
     static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
...
@@ -59,10 +57,9 @@ __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
     return y;
 }

 template <typename... Xs,
           typename Y,
-          enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value,
-                      bool> = false>
+          enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
 {
     static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...
@@ -73,10 +70,9 @@ __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
     return r;
 }

 template <typename... Xs,
           typename Y,
-          enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value,
-                      bool> = false>
+          enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
 {
     static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...
@@ -87,10 +83,9 @@ __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
     return r;
 }

 template <typename... Xs,
           typename Y,
-          enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value,
-                      bool> = false>
+          enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
 {
     static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...
@@ -104,7 +99,7 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
 // MultiIndex = scalar * MultiIndex
 template <typename... Xs,
           typename Y,
-          enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
+          enable_if_t<ck::is_integral<Y>::value || ck::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
 {
     constexpr index_t NSize = sizeof...(Xs);
...
@@ -117,7 +112,7 @@ __host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
 // MultiIndex = MultiIndex * scalar
 template <typename... Xs,
           typename Y,
-          enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
+          enable_if_t<ck::is_integral<Y>::value || ck::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, Y a)
 {
     return a * x;
...
include/ck/utility/tuple.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -32,7 +32,7 @@ struct TupleElementKeyData
     template <typename T,
               typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
                                  bool>::type = false>
-    __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
+    __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(ck::forward<T>(v))
     {
     }
...
@@ -67,7 +67,7 @@ get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
 template <typename Key, typename Data>
 __host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
 {
-    return std::forward(x.mData);
+    return ck::forward(x.mData);
 }

 template <typename Indices, typename... Xs>
...
@@ -83,13 +83,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
                                  !is_same<remove_cvref_t<Y>, TupleImpl>::value, bool>::type = false>
     __host__ __device__ constexpr TupleImpl(Y&& y)
-        : TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
+        : TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Y>(y))...
     {
     }

     template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
     __host__ __device__ constexpr TupleImpl(Ys&&... ys)
-        : TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
+        : TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Ys>(ys))...
     {
         static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
                       "wrong! inconsistent size");
...
@@ -123,14 +123,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
     template <typename Y,
               typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value,
                                  bool>::type = false>
-    __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
+    __host__ __device__ constexpr Tuple(Y&& y) : base(ck::forward<Y>(y))
     {
     }

     template <typename... Ys,
               typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2,
                                  bool>::type = false>
-    __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
+    __host__ __device__ constexpr Tuple(Ys&&... ys) : base(ck::forward<Ys>(ys)...)
     {
     }
...
@@ -210,7 +210,7 @@ using tuple_element_t = typename tuple_element<I, TTuple>::type;
 template <typename... Xs>
 __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
 {
-    return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
+    return Tuple<remove_cvref_t<Xs>...>(ck::forward<Xs>(xs)...);
 }

 // https://en.cppreference.com/w/cpp/utility/tuple/tie
...
include/ck/utility/tuple_helper.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "functional4.hpp"
 #include "tuple.hpp"
+#ifndef CK_CODE_GEN_RTC
+#include "is_detected.hpp"
+#endif

 namespace ck {
...
@@ -29,7 +31,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
                                                              const Tuple<Y&...>& ty)
 {
     return unpack2(
-        [&](auto&&... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
+        [&](auto&&... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
         tx,
         ty);
 }
...
@@ -38,7 +40,7 @@ template <typename... X, typename... Y>
 __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
 {
     return unpack2(
-        [&](auto... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
+        [&](auto... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
         tx,
         ty);
 }
...
@@ -157,13 +159,17 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
     }
 }

+#ifndef CK_CODE_GEN_RTC
 template <typename T>
-using is_tuple = decltype(std::declval<T&>().IsTuple());
+using is_tuple = decltype(ck::declval<T&>().IsTuple());
+#endif

 template <typename... Ts>
 __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
 {
+#ifndef CK_CODE_GEN_RTC
     return (is_detected<is_tuple, Ts>::value || ...);
+#endif
 }

 template <index_t depth = 0, typename T>
...
include/ck/utility/type.hpp  (view file @ d480a5a6)

The file is largely rewritten (+316 -49). Removed contents:

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/enable_if.hpp"

namespace ck {

template <typename X, typename Y>
struct is_same : public integral_constant<bool, false>
{
};

template <typename X>
struct is_same<X, X> : public integral_constant<bool, true>
{
};

template <typename X, typename Y>
inline constexpr bool is_same_v = is_same<X, Y>::value;

template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;

template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;

template <typename T>
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;

template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type;

template <typename T>
inline constexpr bool is_pointer_v = std::is_pointer<T>::value;

template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y bit_cast(const X& x)
{
    static_assert(__has_builtin(__builtin_bit_cast), "");
    static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");

    return __builtin_bit_cast(Y, x);
}

} // namespace ck

New contents:

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/enable_if.hpp"
#include "ck/utility/integral_constant.hpp"

namespace ck {

#ifdef CK_CODE_GEN_RTC
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT1(name)         \
    template <class T>                       \
    struct name : bool_constant<__##name(T)> \
    {                                        \
    }

// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT2(name)            \
    template <class T, class U>                 \
    struct name : bool_constant<__##name(T, U)> \
    {                                           \
    }

// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAITN(name)            \
    template <class... Ts>                      \
    struct name : bool_constant<__##name(Ts...)> \
    {                                           \
    }

CK_BUILTIN_TYPE_TRAIT1(is_class);
CK_BUILTIN_TYPE_TRAIT1(is_pointer);
CK_BUILTIN_TYPE_TRAIT1(is_reference);
CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
CK_BUILTIN_TYPE_TRAIT1(is_unsigned);
CK_BUILTIN_TYPE_TRAIT2(is_base_of);

template <class T>
struct remove_cv
{
    using type = T;
};
template <class T>
struct remove_cv<const T> : remove_cv<T>
{
};
template <class T>
struct remove_cv<volatile T> : remove_cv<T>
{
};

template <class T>
struct remove_reference
{
    typedef T type;
};
template <class T>
struct remove_reference<T&>
{
    typedef T type;
};
template <class T>
struct remove_reference<T&&>
{
    typedef T type;
};

template <class T>
struct remove_pointer
{
    typedef T type;
};
template <class T>
struct remove_pointer<T*>
{
    typedef T type;
};
template <class T>
struct remove_pointer<T* const>
{
    typedef T type;
};
template <class T>
struct remove_pointer<T* volatile>
{
    typedef T type;
};
template <class T>
struct remove_pointer<T* const volatile>
{
    typedef T type;
};

template <typename T>
constexpr T&& forward(typename remove_reference<T>::type& t_) noexcept
{
    return static_cast<T&&>(t_);
}

template <typename T>
constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
{
    return static_cast<T&&>(t_);
}

template <class T>
struct is_const : public integral_constant<bool, false>
{
};
template <class T>
struct is_const<const T> : public integral_constant<bool, true>
{
};

template <class T>
inline constexpr bool is_const_v = is_const<T>::value;

template <typename T>
inline constexpr bool is_reference_v = is_reference<T>::value;

template <class T>
struct remove_const
{
    typedef T type;
};
template <class T>
struct remove_const<const T>
{
    typedef T type;
};

template <class T>
using remove_const_t = typename remove_const<T>::type;

template <class T>
inline constexpr bool is_class_v = is_class<T>::value;

template <class T>
inline constexpr bool is_trivially_copyable_v = is_trivially_copyable<T>::value;

// template <typename T>
// T&& declval() noexcept;

template <class T, class U = T&&>
U private_declval(int);

template <class T>
T private_declval(long);

template <class T>
auto declval() noexcept -> decltype(private_declval<T>(0));

template <class...>
using void_t = void;

#else
#include <utility>
#include <type_traits>

using std::declval;
using std::forward;
using std::is_base_of;
using std::is_class;
using std::is_class_v;
using std::is_const_v;
using std::is_pointer;
using std::is_reference;
using std::is_reference_v;
using std::is_trivially_copyable;
using std::is_trivially_copyable_v;
using std::is_unsigned;
using std::remove_const_t;
using std::remove_cv;
using std::remove_pointer;
using std::remove_reference;
using std::void_t;
#endif

template <typename X, typename Y>
struct is_same : public integral_constant<bool, false>
{
};
template <typename X>
struct is_same<X, X> : public integral_constant<bool, true>
{
};

template <typename X>
struct is_floating_point : public integral_constant<bool, false>
{
};
template <> struct is_floating_point<float> : public integral_constant<bool, true> {};
template <> struct is_floating_point<double> : public integral_constant<bool, true> {};
template <> struct is_floating_point<long double> : public integral_constant<bool, true> {};

template <typename X>
struct is_integral : public integral_constant<bool, false>
{
};
template <> struct is_integral<int> : public integral_constant<bool, true> {};
template <> struct is_integral<unsigned int> : public integral_constant<bool, true> {};
template <> struct is_integral<long> : public integral_constant<bool, true> {};
template <> struct is_integral<unsigned long> : public integral_constant<bool, true> {};
template <> struct is_integral<short> : public integral_constant<bool, true> {};
template <> struct is_integral<unsigned short> : public integral_constant<bool, true> {};
template <> struct is_integral<long long> : public integral_constant<bool, true> {};
template <> struct is_integral<unsigned long long> : public integral_constant<bool, true> {};
template <> struct is_integral<char> : public integral_constant<bool, true> {};
template <> struct is_integral<signed char> : public integral_constant<bool, true> {};
template <> struct is_integral<unsigned char> : public integral_constant<bool, true> {};
template <> struct is_integral<wchar_t> : public integral_constant<bool, true> {};
template <> struct is_integral<char16_t> : public integral_constant<bool, true> {};
template <> struct is_integral<char32_t> : public integral_constant<bool, true> {};
template <> struct is_integral<bool> : public integral_constant<bool, true> {};

template <typename X, typename Y>
inline constexpr bool is_same_v = is_same<X, Y>::value;

template <typename X, typename Y>
inline constexpr bool is_base_of_v = is_base_of<X, Y>::value;

template <typename T>
inline constexpr bool is_unsigned_v = is_unsigned<T>::value;

template <typename T>
using remove_reference_t = typename remove_reference<T>::type;

template <typename T>
using remove_cv_t = typename remove_cv<T>::type;

template <typename T>
using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;

template <typename T>
using remove_pointer_t = typename remove_pointer<T>::type;

template <typename T>
inline constexpr bool is_pointer_v = is_pointer<T>::value;

template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y bit_cast(const X& x)
{
    static_assert(__has_builtin(__builtin_bit_cast), "");
    static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");

    return __builtin_bit_cast(Y, x);
}

} // namespace ck
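To see what the RTC branch's trait macros generate: CK_BUILTIN_TYPE_TRAIT1(is_class) pastes the trait name onto a compiler builtin (__is_class), so no <type_traits> is needed. Reproduced by hand below (a sketch; bool_constant here is a local stand-in for the alias this commit adds to integral_constant.hpp):

    template <bool B>
    struct bool_constant
    {
        static constexpr bool value = B;
    };

    // Hand-expanded form of CK_BUILTIN_TYPE_TRAIT1(is_class):
    template <class T>
    struct is_class_sketch : bool_constant<__is_class(T)>
    {
    };

    struct S {};
    static_assert(is_class_sketch<S>::value, "struct is a class type");
    static_assert(!is_class_sketch<int>::value, "int is not a class type");

    int main() { return 0; }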
include/ck/utility/type_convert.hpp  (view file @ d480a5a6)

...
@@ -52,10 +52,10 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
 // Convert X to Y, both X and Y are non-const data types.
 template <typename Y,
           typename X,
-          std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
+          ck::enable_if_t<!(ck::is_const_v<Y> || ck::is_const_v<X>), bool> = false>
 __host__ __device__ constexpr Y type_convert(X x)
 {
-    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
+    static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);

     return static_cast<Y>(x);
 }
...
@@ -63,13 +63,13 @@ __host__ __device__ constexpr Y type_convert(X x)
 // Convert X to Y, either X or Y is a const data type.
 template <typename Y,
           typename X,
-          std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
+          ck::enable_if_t<ck::is_const_v<Y> || ck::is_const_v<X>, bool> = false>
 __host__ __device__ constexpr Y type_convert(X x)
 {
-    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
+    static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);

-    using NonConstY = std::remove_const_t<Y>;
-    using NonConstX = std::remove_const_t<X>;
+    using NonConstY = ck::remove_const_t<Y>;
+    using NonConstX = ck::remove_const_t<X>;
     return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
 }
...
@@ -149,7 +149,7 @@ inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int
 template <typename Y, typename X>
 __host__ __device__ constexpr Y type_convert_sp(X x)
 {
-    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
+    static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);

     return static_cast<Y>(x);
 }
...
@@ -211,7 +211,11 @@ template <>
 inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
 {
     constexpr int seed = 1254739;
-    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#ifndef CK_CODE_GEN_RTC
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
 #if defined(__gfx94__)
     union
     {
...
@@ -251,7 +255,12 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
     constexpr int seed = 1254739;
+#ifndef CK_CODE_GEN_RTC
     uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
     return utils::cast_to_f8<half_t,
                              f8_fnuz_t,
                              negative_zero_nan,
...
@@ -265,7 +274,11 @@ template <>
 inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
 {
     constexpr int seed = 1254739;
-    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#ifndef CK_CODE_GEN_RTC
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
 #if defined(__gfx94__)
     union
     {
...
@@ -307,7 +320,12 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
     constexpr int seed = 1254739;
+#ifndef CK_CODE_GEN_RTC
     uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
     return utils::cast_to_f8<half_t,
                              bf8_fnuz_t,
                              negative_zero_nan,
...
@@ -629,20 +647,22 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
 #endif
 }

-template <typename Y, typename X, std::size_t NumElems>
+#ifndef CK_CODE_GEN_RTC
+template <typename Y, typename X, size_t NumElems>
 inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
                                               const std::array<X, NumElems>& x)
 {
-    for(std::size_t i = 0; i < NumElems; i++)
+    for(size_t i = 0; i < NumElems; i++)
     {
         y[i] = type_convert<Y>(x[i]);
     }
 }
+#endif

 template <typename Y, typename X, index_t NumElems>
 inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
 {
-    for(std::size_t i = 0; i < NumElems; i++)
+    for(size_t i = 0; i < NumElems; i++)
     {
         y[i] = type_convert<Y>(x[i]);
     }
...
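Usage sketch for the std::array overload of array_convert that the last hunk keeps under the non-RTC path (simplified: a plain static_cast stands in for ck's type_convert; the loop structure matches the diff):

    #include <array>
    #include <cstddef>
    #include <cstdio>

    template <typename Y, typename X, std::size_t NumElems>
    inline void array_convert_sketch(std::array<Y, NumElems>& y,
                                     const std::array<X, NumElems>& x)
    {
        for(std::size_t i = 0; i < NumElems; i++)
        {
            y[i] = static_cast<Y>(x[i]); // element-wise conversion
        }
    }

    int main()
    {
        std::array<float, 3> src{1.5f, 2.5f, 3.5f};
        std::array<int, 3>   dst{};
        array_convert_sketch(dst, src);
        std::printf("%d %d %d\n", dst[0], dst[1], dst[2]); // prints: 1 2 3
        return 0;
    }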
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core.hpp"
-#define CK_TILE_MAX_RANK 5
+#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"

 namespace ck_tile {

 // this epilogue aiming to store a matrix with different layout from the shared memory to the global
 // memory.
 template <typename AccDataType_,
           typename ODataType_,
-          bool kPadM_,
-          bool kPadN_,
-          bool kTilePermute_,
-          index_t kRank_,
-          index_t kPerm0,
-          index_t kPerm1,
-          index_t TileSize0,
-          index_t TileSize1,
-          index_t kPerm2    = 0,
-          index_t kPerm3    = 0,
-          index_t kPerm4    = 0,
-          index_t TileSize2 = 0,
-          index_t TileSize3 = 0,
-          index_t TileSize4 = 0>
+          typename CLayout_,
+          index_t kBlockSize_,
+          index_t kM_,
+          index_t kN_,
+          index_t kMWave_,
+          index_t kNWave_,
+          index_t kMPerXdl_,
+          index_t kNPerXdl_,
+          index_t kKPerXdl_,
+          bool isCTransposed_>
 struct CShuffleEpilogueProblem
 {
-    using AccDataType = remove_cvref_t<AccDataType_>;
-    using ODataType   = remove_cvref_t<ODataType_>;
-    static constexpr bool kPadM        = kPadM_;
-    static constexpr bool kPadN        = kPadN_;
-    static constexpr bool kTilePermute = kTilePermute_;
-    static constexpr index_t kRank     = kRank_;
-    static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4};
-    static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = {
-        TileSize0, TileSize1, TileSize2, TileSize3, TileSize4};
+    using AccDataType = remove_cvref_t<AccDataType_>;
+    using ODataType   = remove_cvref_t<ODataType_>;
+    using CLayout     = remove_cvref_t<CLayout_>;
+    static constexpr index_t kBlockSize    = kBlockSize_;
+    static constexpr index_t kMPerBlock    = kM_;
+    static constexpr index_t kNPerBlock    = kN_;
+    static constexpr index_t kMWave        = kMWave_;
+    static constexpr index_t kNWave        = kNWave_;
+    static constexpr index_t kMPerXdl      = kMPerXdl_;
+    static constexpr index_t kNPerXdl      = kNPerXdl_;
+    static constexpr index_t kKPerXdl      = kKPerXdl_;
+    static constexpr index_t isCTransposed = isCTransposed_;
 };

 template <typename Problem_, typename Policy_ = void>
 struct CShuffleEpilogue
 {
-    using Problem     = remove_cvref_t<Problem_>;
-    using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
-    using ODataType   = remove_cvref_t<typename Problem::ODataType>;
-    static constexpr bool kPadM        = Problem::kPadM;
-    static constexpr bool kPadN        = Problem::kPadN;
-    const index_t* kPerm               = Problem::kPerm;
-    static constexpr bool kTilePermute = Problem::kTilePermute;
-    static constexpr index_t kRank     = Problem::kRank;
-    const index_t* tile_sizes          = Problem::tile_sizes;
-
-    // No additional shared memory needed
-    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
+    using Problem     = remove_cvref_t<Problem_>;
+    using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
+    using ODataType   = remove_cvref_t<typename Problem::ODataType>;
+    using CLayout     = remove_cvref_t<typename Problem::CLayout>;
+    static constexpr index_t kBlockSize     = Problem::kBlockSize;
+    static constexpr index_t kMPerBlock     = Problem::kMPerBlock;
+    static constexpr index_t kNPerBlock     = Problem::kNPerBlock;
+    static constexpr index_t kMWave         = Problem::kMWave;
+    static constexpr index_t kNWave         = Problem::kNWave;
+    static constexpr index_t kMPerXdl       = Problem::kMPerXdl;
+    static constexpr index_t kNPerXdl       = Problem::kNPerXdl;
+    static constexpr index_t kKPerXdl       = Problem::kKPerXdl;
+    static constexpr index_t isCTransposed  = Problem::isCTransposed;
+    static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
+    static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
+
+    using WG = WarpGemmMfmaDispatcher<ODataType,
+                                      ODataType,
+                                      AccDataType,
+                                      kMPerXdl,
+                                      kNPerXdl,
+                                      kKPerXdl,
+                                      isCTransposed>;
+    using CWarpDstr   = typename WG::CWarpDstr;
+    using CWarpTensor = typename WG::CWarpTensor;

-    CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed()
-    {
-        // TODO: At now CShuffle doesn't allow to vector store after permute.
-        // It should be fixed and this function should return true.
-        return false;
-    }
+    /**
+     * @brief Get the vector store size for C tensor.
+     *
+     * @note The vector store size for output C tensor would depend on multiple factors
+     *       like its data layout and warp gemm C transposition. In general it would
+     *       be the number of consecutive elements in contiguous C dimension hold by
+     *       single thread.
+     *
+     * @return The vector store size for C tensor.
+     */
+    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
+    {
+        constexpr index_t MaxVectorStoreSize = 16;
+        return MaxVectorStoreSize / sizeof(ODataType);
+    }

-    template <typename OAccTile>
-    CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
-    {
-        using DataType = typename OAccTile::DataType;
-        // Get thread buffer
-        auto& thread_buf = o_acc_tile.get_thread_buffer();
-        // Create a temporary buffer to hold the permuted data
-        thread_buffer<DataType, OAccTile::kThreadElementSpaceSize> permuted_thread_buf;
-        // Get the lengths of each dimension
-        auto thread_tensor_lengths = o_acc_tile.get_lengths();
-        // Total number of elements
-        index_t total_elements = OAccTile::kThreadElementSpaceSize;
-        // Iterate over all elements
-        for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx)
-        {
-            // Convert linear index to multi-dimensional indices
-            array<index_t, kRank> indices;
-            index_t remaining = linear_idx;
-            static_for<0, kRank, 1>{}([&](auto i) {
-                constexpr auto rev_i = kRank - 1 - i;
-                indices(rev_i) = remaining % thread_tensor_lengths.get(number<rev_i>{});
-                remaining /= thread_tensor_lengths.get(number<rev_i>{});
-            });
-            // Apply the permutation
-            array<index_t, kRank> permuted_indices;
-            static_for<0, kRank, 1>{}(
-                [&](auto i) { permuted_indices(i) = indices.get(number<Problem::kPerm[i]>{}); });
-            // Compute offsets
-            index_t dst_offset = 0;
-            index_t stride     = 1;
-            static_for<0, kRank, 1>{}([&](auto i) {
-                constexpr auto rev_i = kRank - 1 - i;
-                dst_offset += permuted_indices[rev_i] * stride;
-                stride *= thread_tensor_lengths.get(number<rev_i>{});
-            });
-            // Move the data
-            permuted_thread_buf(dst_offset) = thread_buf[linear_idx];
-        }
-        // Copy the permuted data back to the original thread buffer
-        for(index_t i = 0; i < total_elements; ++i)
-        {
-            thread_buf.set_as(i, permuted_thread_buf.get(i));
-        }
-    }
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
+    {
+        // N is contiguous dimension
+        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
+        {
+            return make_naive_tensor_descriptor(
+                make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
+                make_tuple(number<kNWave * kNPerXdl>{}, number<1>{}));
+        }
+        // M is contiguous dimension
+        else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
+        {
+            return make_naive_tensor_descriptor(
+                make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
+                make_tuple(number<1>{}, number<kMWave * kMPerXdl>{}));
+        }
+        else
+        {
+            static_assert(false, "Unsupported CLayout!");
+        }
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    {
+        return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType);
+    }

-    template <typename ODramWindowTmp,
+    template <typename ODramWindow,
               typename OAccTile,
               memory_operation_enum out_memory_data_op = memory_operation_enum::set>
-    CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
+    CK_TILE_DEVICE auto
+    operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem)
     {
-        const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
-
-        // Compute the tile coordinates by dividing the window origin by the tile sizes
-        index_t tile_coords[CK_TILE_MAX_RANK] = {0};
-        for(index_t i = 0; i < kRank; ++i)
-        {
-            tile_coords[i] = current_window_origin[i] / tile_sizes[i];
-            // printf("The tile_coord is: %d", tile_coords[i]);
-        }
-
-        // Apply the permutation to the tile coordinates
-        index_t permuted_tile_coords[CK_TILE_MAX_RANK];
-        for(index_t i = 0; i < kRank; ++i)
-        {
-            permuted_tile_coords[i] = tile_coords[kPerm[i]];
-            // printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
-        }
-
-        // Compute the permuted window origin
-        index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0};
-        for(index_t i = 0; i < kRank; ++i)
-        {
-            permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i];
-            // printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
-        }
-
-        typename ODramWindowTmp::BottomTensorIndex step = {};
-        for(index_t i = 0; i < kRank; ++i)
-        {
-            step[i] = permuted_window_origin[i] - current_window_origin[i];
-        }
-
-        // Move the window
-        move_tile_window(o_dram_window_tmp, step);
-
-        // Permute the data within the tile if necessary
-        if constexpr(kTilePermute)
-        {
-            permute_tile_data(o_acc_tile);
-        }
-
-        // Store the tile data to the permuted location
-        if constexpr(kPadM || kPadN)
-        {
-            if constexpr(out_memory_data_op == memory_operation_enum::set)
-            {
-                store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
-            }
-            else
-            {
-                update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
-            }
-            buffer_store_fence();
-        }
-        else
-        {
-            if constexpr(out_memory_data_op == memory_operation_enum::set)
-            {
-                store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
-            }
-            else
-            {
-                update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
-            }
-        }
+        const index_t iMWarp = get_warp_id() / kNWave;
+        const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
+
+        constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
+        auto o_lds_block = make_tensor_view<address_space_enum::lds>(
+            static_cast<ODataType*>(p_smem), lds_block_desc);
+
+        auto in_lds_window =
+            make_tile_window(o_lds_block,
+                             make_tuple(number<kMPerXdl>{}, number<kNPerXdl>{}),
+                             {number<kMPerXdl>{} * iMWarp, number<kNPerXdl>{} * iNWarp});
+
+        auto out_lds_window =
+            make_tile_window(o_lds_block,
+                             make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
+                             {0, 0});
+
+        using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
+                                        sequence<0, 1>,
+                                        sequence<kMPerXdl * kMWave, kNPerXdl * kNWave>>;
+        constexpr index_t num_access = SFC::get_num_of_access();
+
+        using TileEncodingPattern =
+            TileDistributionEncodingPattern2D<kBlockSize,
+                                              kMPerIteration,
+                                              kNPerIteration,
+                                              GetVectorSizeC(),
+                                              tile_distribution_pattern::thread_raked>;
+        constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
+
+        constexpr auto c_warp_y_lengths =
+            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
+        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
+
+        CWarpTensor c_warp_in_tensor;
+
+        static_for<0, num_access, 1>{}([&](auto iAccess) {
+            constexpr auto idx_y_start = SFC::get_index(iAccess);
+            constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (kMPerXdl * kMWave)>{};
+            constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (kNPerXdl * kNWave)>{};
+
+            c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
+                merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
+
+            const auto c_warp_in_tensor_casted = cast_tile<ODataType>(c_warp_in_tensor);
+
+            block_sync_lds();
+            store_tile(in_lds_window, c_warp_in_tensor_casted);
+            block_sync_lds();
+
+            const auto c_out_tensor =
+                load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
+
+            if constexpr(out_memory_data_op == memory_operation_enum::set)
+            {
+                store_tile(out_dram_window, c_out_tensor);
+            }
+            else
+            {
+                update_tile(out_dram_window, c_out_tensor);
+            }
+
+            if constexpr(iAccess != num_access - 1)
+            {
+                constexpr auto step = SFC::get_forward_step(iAccess);
+                move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
+            }
+        });
     }
 };
 } // namespace ck_tile
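For intuition on the new GetVectorSizeC: with MaxVectorStoreSize = 16 bytes as above, the store width in elements follows directly from sizeof(ODataType). The element sizes below are standard type sizes, used here only to illustrate the arithmetic:

    #include <cstdint>

    // vector size = MaxVectorStoreSize / sizeof(ODataType), for 16-byte stores
    static_assert(16 / sizeof(uint16_t) == 8, "fp16/bf16: 8 elements per store");
    static_assert(16 / sizeof(float) == 4, "fp32: 4 elements per store");
    static_assert(16 / sizeof(uint8_t) == 16, "fp8: 16 elements per store");

    int main() { return 0; }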
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp  (view file @ d480a5a6)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core.hpp"
+#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"

 namespace ck_tile {
...
@@ -23,6 +25,26 @@ struct Default2DEpilogueProblem
     static constexpr bool UseRawStore = UseRawStore_;
 };

+template <typename AccDataType_,
+          typename ODataType_,
+          typename CLayout_,
+          bool kPadM_,
+          bool kPadN_,
+          index_t kMPerXdl_,
+          index_t kNPerXdl_,
+          index_t kKPerXdl_,
+          bool isCTransposed_,
+          bool UseRawStore_ = true>
+struct DefaultGemm2DEpilogueProblem
+    : public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
+{
+    using CLayout = remove_cvref_t<CLayout_>;
+    static constexpr index_t kMPerXdl      = kMPerXdl_;
+    static constexpr index_t kNPerXdl      = kNPerXdl_;
+    static constexpr index_t kKPerXdl      = kKPerXdl_;
+    static constexpr index_t isCTransposed = isCTransposed_;
+};
+
 template <typename Problem_, typename Policy_ = void>
 struct Default2DEpilogue
 {
...
@@ -35,14 +57,13 @@ struct Default2DEpilogue
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }

-    CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; }
-
     // TODO: this function assume store out vector size is the same as OAccTile last dimension size
     // how do we fix this ?
     template <typename ODramWindowTmp,
               typename OAccTile,
               memory_operation_enum out_memory_data_op = memory_operation_enum::set>
-    CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile)
+    CK_TILE_DEVICE auto
+    operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
     {

         // TODO: this is ugly
...
@@ -71,4 +92,76 @@ struct Default2DEpilogue
         }
     }
 };
+
+template <typename Problem_, typename Policy_ = void>
+struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
+{
+    using Problem     = remove_cvref_t<Problem_>;
+    using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
+    using ODataType   = remove_cvref_t<typename Problem::ODataType>;
+    using CLayout     = remove_cvref_t<typename Problem::CLayout>;
+    static constexpr index_t kMPerXdl      = Problem::kMPerXdl;
+    static constexpr index_t kNPerXdl      = Problem::kNPerXdl;
+    static constexpr index_t kKPerXdl      = Problem::kKPerXdl;
+    static constexpr index_t isCTransposed = Problem::isCTransposed;
+
+    using WG = WarpGemmMfmaDispatcher<ODataType,
+                                      ODataType,
+                                      AccDataType,
+                                      kMPerXdl,
+                                      kNPerXdl,
+                                      kKPerXdl,
+                                      isCTransposed>;
+    using CWarpDstr = typename WG::CWarpDstr;
+
+    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
+    {
+        // N is contiguous dimension
+        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
+        {
+            if constexpr(isCTransposed)
+            {
+                // In this case each thread has multiple consecutive elements in
+                // N dimension, however consecutive threads' elements have stride.
+                constexpr index_t NDimY = CWarpDstr::NDimY;
+                constexpr auto c_warp_y_lengths =
+                    CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
+                static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
+                              c_warp_y_lengths.get(number<NDimY - 1>{}));
+                return c_warp_y_lengths.get(number<NDimY - 1>{});
+            }
+            else
+            {
+                // In this case each thread has just a single item in Ndim
+                return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
+            }
+        }
+        // M is contiguous dimension
+        else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
+        {
+            if constexpr(isCTransposed)
+            {
+                // In this case each thread has just a single item in Mdim
+                return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
+            }
+            else
+            {
+                // In this case each thread has multiple consecutive elements in
+                // M dimension, however consecutive threads' elements have stride.
+                constexpr index_t NDimY = CWarpDstr::NDimY;
+                constexpr auto c_warp_y_lengths =
+                    CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
+                static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
+                              c_warp_y_lengths.get(number<NDimY - 1>{}));
+                return c_warp_y_lengths.get(number<NDimY - 1>{});
+            }
+        }
+        else
+        {
+            static_assert(false, "Unsupported CLayout!");
+        }
+    }
+};
 } // namespace ck_tile
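The layout dispatch in GetVectorSizeC hinges on which C dimension is unit-stride; a tiny sketch of the offset arithmetic behind the RowMajor/ColumnMajor branches:

    #include <cassert>

    int main()
    {
        const int M = 4, N = 8;
        auto row_major = [&](int m, int n) { return m * N + n; }; // N is contiguous
        auto col_major = [&](int m, int n) { return n * M + m; }; // M is contiguous
        assert(row_major(1, 3) + 1 == row_major(1, 4)); // unit stride along N
        assert(col_major(2, 5) + 1 == col_major(3, 5)); // unit stride along M
        return 0;
    }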
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp  (view file @ d480a5a6)

...
@@ -70,7 +70,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
     __host__ static constexpr auto
     GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
     {
-        return TilePartitioner::GridSize(M, N, KBatch * batch_count);
+        return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
     }

     __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
...
@@ -101,14 +101,14 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
     CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
     {
-        const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
+        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
         const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
         const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

-        const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
-        const auto i_k     = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
+        const auto i_batch  = __builtin_amdgcn_readfirstlane(blockIdx.y);
+        const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);

-        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);
+        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);

         // options
         const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
...
@@ -128,7 +128,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
         // allocate LDS
         __shared__ char smem_ptr[GetSmemSize()];

-        if(kargs.KBatch == 1)
+        if(kargs.k_batch == 1)
         {
             this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
         }
...
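The GridSize/operator() changes above replace a flattened z axis (batch and split-K packed into blockIdx.z and recovered by division) with a direct 3D launch: blockIdx.y carries the batch index and blockIdx.z the split-K index. A host-side sketch of why the two schemes select the same (batch, split-K) pairs:

    #include <cassert>

    int main()
    {
        const int KBatch = 4, batch_count = 3;
        for(int z = 0; z < KBatch * batch_count; ++z)
        {
            const int i_batch = z / KBatch; // old decode from the flattened grid
            const int i_k     = z % KBatch;
            // new layout: blockIdx.y == i_batch, blockIdx.z == i_k, no division
            assert(i_batch * KBatch + i_k == z);
        }
        return 0;
    }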
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp  (view file @ d480a5a6)

...
@@ -75,12 +75,12 @@ struct GemmKernel
     static constexpr auto I1 = number<1>();
     static constexpr auto I2 = number<2>();

-    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
+    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
     {
-        return TilePartitioner::GridSize(M, N, KBatch);
+        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
     }

-    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
+    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }

     struct GemmKernelArgs
     {
...
@@ -93,7 +93,7 @@ struct GemmKernel
         index_t stride_A;
         index_t stride_B;
         index_t stride_C;
-        index_t KBatch;
+        index_t k_batch;
     };

     CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
...
@@ -121,7 +121,7 @@ struct GemmKernel
                             const std::size_t k_id = blockIdx.z)
         {
             constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
-            const index_t K_t   = kargs.KBatch * K1;
+            const index_t K_t   = kargs.k_batch * K1;
             const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;

             if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
...
@@ -142,13 +142,13 @@ struct GemmKernel
                 b_k_split_offset = k_id * KRead;
             }

-            if(k_id < static_cast<uint32_t>(kargs.KBatch - 1))
+            if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
             {
                 splitted_k = KRead;
             }
             else
             {
-                splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
+                splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
             }
         }
...
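Worked example of the SplitKBatchOffset arithmetic above (values chosen for illustration): with K1 the warp-tile K extent, each of the first k_batch - 1 partitions reads KRead elements along K, and the last partition takes the remainder, so the partitions cover K exactly.

    #include <cassert>

    int main()
    {
        const int K = 1000, k_batch = 3, K1 = 32;
        const int K_t   = k_batch * K1;             // 96
        const int KRead = (K + K_t - 1) / K_t * K1; // ceil(1000/96) * 32 = 352
        assert(KRead == 352);
        const int last = K - KRead * (k_batch - 1); // 1000 - 2*352 = 296
        assert(last == 296);
        assert(KRead * (k_batch - 1) + last == K);  // exact coverage of K
        return 0;
    }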
@@ -159,14 +159,10 @@ struct GemmKernel
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmKernelArgs
&
kargs
)
{
constexpr
bool
is_output_c_reg_transposed
=
EpiloguePipeline
::
IsOutputTransposed
()
!=
GemmPipeline
::
IsTransposeC
();
if
constexpr
(
!
((
GemmPipeline
::
VectorSizeC
%
2
==
0
&&
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_output_c_reg_transposed
)
||
!
(
std
::
is_same_v
<
CDataType
,
fp16_t
>
||
std
::
is_same_v
<
CDataType
,
bf16_t
>
)))
if
constexpr
(
EpiloguePipeline
::
GetVectorSizeC
()
%
2
!=
0
&&
is_any_of
<
CDataType
,
fp16_t
,
bf16_t
>::
value
)
{
if
(
kargs
.
KB
atch
!=
1
)
if
(
kargs
.
k_b
atch
!=
1
)
{
std
::
cerr
<<
"Conditions not met for Kbatch >1 !"
<<
std
::
endl
;
return
false
;
...
...
@@ -182,7 +178,7 @@ struct GemmKernel
                          << std::endl;
                return false;
            }
-           if(kargs.K % GemmPipeline::VectorSizeA != 0)
+           if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
            {
                std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
                return false;
...
@@ -197,7 +193,7 @@ struct GemmKernel
                          << std::endl;
                return false;
            }
-           if(kargs.M % GemmPipeline::VectorSizeA != 0)
+           if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
            {
                std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
                return false;
...
@@ -213,7 +209,7 @@ struct GemmKernel
                          << std::endl;
                return false;
            }
-           if(kargs.N % GemmPipeline::VectorSizeB != 0)
+           if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
            {
                std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
                return false;
...
@@ -228,7 +224,7 @@ struct GemmKernel
                          << std::endl;
                return false;
            }
-           if(kargs.K % GemmPipeline::VectorSizeB != 0)
+           if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
            {
                std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
                return false;
...
@@ -244,7 +240,7 @@ struct GemmKernel
                          << std::endl;
                return false;
            }
-           if(kargs.N % GemmPipeline::VectorSizeC != 0)
+           if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
                return false;
...
@@ -259,7 +255,7 @@ struct GemmKernel
                          << std::endl;
                return false;
            }
-           if(kargs.M % GemmPipeline::VectorSizeC != 0)
+           if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
                return false;
...
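Every one of these checks enforces the same rule: the tensor's contiguous dimension must be divisible by the pipeline's vector access size, and which dimension is contiguous depends on the layout. A minimal sketch of the rule (names are illustrative, not the library API):

// For a row-major 2D tensor the contiguous (vectorized) dimension is the
// column count; for a column-major tensor it is the row count.
bool is_vector_aligned(bool row_major, int rows, int cols, int vector_size)
{
    const int contiguous_dim = row_major ? cols : rows;
    return contiguous_dim % vector_size == 0;
}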
@@ -275,14 +271,6 @@ struct GemmKernel
                                   const GemmKernelArgs& kargs,
                                   const SplitKBatchOffset& splitk_batch_offset)
    {
-       // const auto idxs = TilePartitioner{}();
-       // const auto i_m = idxs.at(number<0>{});
-       // const auto i_n = idxs.at(number<1>{});
-       // // options
-       // const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
-       // const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
-       // // Convert pointers to tensor views
-       // auto a_tensor_view = [&]() {
        const auto& a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
...
@@ -290,7 +278,7 @@ struct GemmKernel
                    a_ptr,
                    make_tuple(kargs.M, splitk_batch_offset.splitted_k),
                    make_tuple(kargs.stride_A, 1),
-                   number<GemmPipeline::VectorSizeA>{},
+                   number<GemmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
            else
...
@@ -299,7 +287,7 @@ struct GemmKernel
                    a_ptr,
                    make_tuple(splitk_batch_offset.splitted_k, kargs.M),
                    make_tuple(kargs.stride_A, 1),
-                   number<GemmPipeline::VectorSizeA>{},
+                   number<GemmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
        }();
...
@@ -311,7 +299,7 @@ struct GemmKernel
                    b_ptr,
                    make_tuple(splitk_batch_offset.splitted_k, kargs.N),
                    make_tuple(kargs.stride_B, 1),
-                   number<GemmPipeline::VectorSizeB>{},
+                   number<GemmPipeline::GetVectorSizeB()>{},
                    number<1>{});
            }
            else
...
@@ -320,7 +308,7 @@ struct GemmKernel
                    b_ptr,
                    make_tuple(kargs.N, splitk_batch_offset.splitted_k),
                    make_tuple(kargs.stride_B, 1),
-                   number<GemmPipeline::VectorSizeB>{},
+                   number<GemmPipeline::GetVectorSizeB()>{},
                    number<1>{});
            }
        }();
...
@@ -333,7 +321,7 @@ struct GemmKernel
                    c_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_C, 1),
-                   number<GemmPipeline::VectorSizeC>{},
+                   number<EpiloguePipeline::GetVectorSizeC()>{},
                    number<1>{});
            }
            else
...
@@ -501,22 +489,14 @@ struct GemmKernel
        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I2);

-       constexpr bool is_output_c_reg_transposed =
-           EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
-       if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) ||
-                    (GemmPipeline::VectorSizeC % 2 == 0 &&
-                     std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
-                     is_output_c_reg_transposed))
-       {
-           EpiloguePipeline{}
-               .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
-                   c_block_window, c_block_tile);
-       }
+       EpiloguePipeline{}
+           .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
+               c_block_window, c_block_tile, smem_ptr);
    }

    CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
    {
-       const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
+       const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
...
@@ -531,14 +511,20 @@ struct GemmKernel
        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

-       if(kargs.KBatch == 1)
+       if(kargs.k_batch == 1)
        {
            RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
        }
        else
        {
-           RunGemm<memory_operation_enum::atomic_add>(
-               a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+           // Do not compile in case where we have unsupported
+           // VectorSizeC & data type configuration.
+           if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
+                          is_any_of<CDataType, fp16_t, bf16_t>::value))
+           {
+               RunGemm<memory_operation_enum::atomic_add>(
+                   a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+           }
        }
    }
};
...
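When k_batch > 1, the workgroups along grid.z each produce a partial result for the same C tile, so the store must be an atomic accumulation; the if constexpr only prevents instantiating that path for VectorSizeC/data-type combinations that IsSupportedArgument already rejects at runtime. A C++20 host-side analogy of the accumulation (a sketch, not the device code path):

#include <atomic>
#include <cstddef>
#include <vector>

// Each split-K batch computes a partial GEMM over its K-slice and adds it
// into the shared C buffer; concurrent writers make the add atomic.
void accumulate_partial(std::vector<std::atomic<float>>& C,
                        const std::vector<float>& partial_result)
{
    for(std::size_t i = 0; i < C.size(); ++i)
        C[i].fetch_add(partial_result[i], std::memory_order_relaxed); // fetch_add on float is C++20
}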
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
View file @ d480a5a6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

/**
 * @file
 * GemmTilePartitioner allows customized mapping between a workgroup and the C-tile it computes.
 */

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

-/** @brief Struct representing 2D block index mapping into 3D output tile space. */
+/**
+ * @brief Class providing 2D workgroup index mapping into 2D output GEMM C-tile space.
+ */
template <typename BlockGemmShapeType>
struct GemmTile2DPartitioner
{
...
@@ -17,21 +25,32 @@ struct GemmTile2DPartitioner
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

-   /** @brief Returns 3D grid size. */
-   CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size)
-       noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
+   CK_TILE_HOST_DEVICE GemmTile2DPartitioner() noexcept = delete;
+   CK_TILE_HOST_DEVICE GemmTile2DPartitioner([[maybe_unused]] index_t M,
+                                             [[maybe_unused]] index_t N) noexcept;
+
+   /**
+    * @brief Calculates GEMM kernel grid size.
+    *
+    * @param M GEMM's M dimension.
+    * @param N GEMM's N dimension.
+    * @return dim3 Structure holding grid's X,Y and Z dimensions.
+    */
+   CK_TILE_HOST static auto GridSize(index_t M, index_t N)
+       noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
    {
        const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
        const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
-       const index_t GridDimZ = batch_size;
-       return dim3(GridDimX, GridDimY, GridDimZ);
+       return dim3(GridDimX, GridDimY, 1);
    }

    /**
-    * @brief Returns the number of loops.
-    * @param [in] K is dimension
+    * @brief Calculate number of loop iterations over GEMM's K dimension.
+    *
+    * @param K GEMM's K dimension.
+    * @return index_t The number of loop iterations over K dimension.
     */
-   CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
+   CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
    {
        return integer_divide_ceil(K, KPerBlock);
    }
...
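Both grid dimensions here are plain ceiling divisions; for positive operands, integer_divide_ceil presumably reduces to the standard idiom:

// The classic ceiling-division idiom for positive integers.
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

static_assert(div_ceil(1000, 256) == 4); // e.g. M = 1000, MPerBlock = 256 -> 4 tile rows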
@@ -42,8 +61,15 @@ struct GemmTile2DPartitioner
-    * @param [in] blockIdy is blockIdx.y
-    * @return Returns the output tile indexes.
-    */
-   CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy)
-       noexcept
+   /**
+    * @brief Calculate workgroup 2D index mapping into 2D output C-tile space.
+    *
+    * @param blockIdx WGP's X index.
+    * @param blockIdy WGP's Y index.
+    * @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
+    */
+   CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept
+       -> const tuple<index_t, index_t>
    {
        const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
...
@@ -53,61 +79,71 @@ struct GemmTile2DPartitioner
};

/**
- * @brief Struct representing 1D block index mapping into 2D output tile space.
+ * @brief Class providing 1D WGP index mapping into 2D output C-tile space.
+ *
+ * @tparam BlockGemmShape_ A class providing basic GEMM parameters. \link TileGemmShape
 */
-template <typename BlockGemmShapeType>
+template <typename BlockGemmShape_>
struct GemmTile1DPartitioner
{
-   using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
+   using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

-   /** @brief delete default ctr with no any object */
-   constexpr GemmTile1DPartitioner() noexcept = delete;
-   /** @brief constructs an object that does contain a N value. */
-   constexpr GemmTile1DPartitioner(index_t N) noexcept { N_ = N; }
+   CK_TILE_HOST_DEVICE GemmTile1DPartitioner() noexcept = delete;

-   /** @brief Returns 1D grid size. */
-   CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N)
-       noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
-   {
-       const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
-       const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
-       return dim3(GridDimX * GridDimY, 1, 1);
-   }
+   /**
+    * @brief Construct a new GemmTile1DPartitioner object.
+    *
+    * @param M GEMM's M dimension.
+    * @param N GEMM's N dimension.
+    */
+   CK_TILE_HOST_DEVICE GemmTile1DPartitioner([[maybe_unused]] index_t M, index_t N) noexcept
+   {
+       N_ = N;
+   }

    /**
-    * @brief Returns the number of blocks in N.
-    * @param [in] N is dimension
+    * @brief Calculates GEMM kernel grid size.
+    *
+    * @param M GEMM's M dimension.
+    * @param N GEMM's N dimension.
+    * @return dim3 Structure holding grid's X,Y and Z dimensions.
     */
-   CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) noexcept -> index_t
+   CK_TILE_HOST static auto GridSize(index_t M, index_t N)
+       noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
    {
-       return integer_divide_ceil(N, NPerBlock);
+       const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
+       const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
+       return GridDimX * GridDimY;
    }
    /**
-    * @brief Returns the number of loops.
-    * @param [in] K is dimension
+    * @brief Calculate number of loop iterations over GEMM's K dimension.
+    *
+    * @param K GEMM's K dimension.
+    * @return index_t The number of loop iterations over K dimension.
     */
-   CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
+   CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
    {
        return integer_divide_ceil(K, KPerBlock);
    }

    /**
-    * @brief The function returns 2D output tile space.
-    * @param [in] blockIdx is blockIdx.x - block_start.
-    * */
-   CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx) noexcept
+    * @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
+    *
+    * @param blockIdx WGP's index.
+    * @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
+    */
+   CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx) noexcept
        -> const tuple<index_t, index_t>
    {
-       const index_t NBlock  = GetNBlock(N_);
-       const index_t iM      = __builtin_amdgcn_readfirstlane(blockIdx / NBlock);
-       const index_t iN      = __builtin_amdgcn_readfirstlane(blockIdx - (iM)*NBlock);
+       const index_t NBlocks = integer_divide_ceil(N_, NPerBlock);
+       const index_t iM      = __builtin_amdgcn_readfirstlane(blockIdx / NBlocks);
+       const index_t iN      = __builtin_amdgcn_readfirstlane(blockIdx - iM * NBlocks);
        return make_tuple(iM, iN);
    }
...
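The 1D partitioner decomposes the linear workgroup index row-major over the number of N-tiles; a trivial host sketch of the mapping (NBlocks chosen arbitrarily):

#include <cstdio>

int main()
{
    const int NBlocks = 4; // hypothetical number of tile columns along N
    for(int block_id = 0; block_id < 8; ++block_id)
    {
        const int iM = block_id / NBlocks;      // tile row
        const int iN = block_id - iM * NBlocks; // tile column
        std::printf("block %d -> tile (%d, %d)\n", block_id, iM, iN);
    }
    // Blocks 0..3 cover tile row 0; blocks 4..7 cover tile row 1.
    return 0;
}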
@@ -141,21 +177,176 @@ struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIn
 * enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed,
 * otherwise std::false_type.
 */
-template <typename PartitionerFn,
-          typename = typename std::enable_if_t<HasFnOneArgImpl<PartitionerFn>{}>>
+template <typename TilePartitioner,
+          typename = typename std::enable_if_t<HasFnOneArgImpl<TilePartitioner>{}>>
struct OffsettedTile1DPartitioner
{
    /**
     * @brief The function subtracts the block's start (offset) from 1D raw-indexes.
-    * @param [in] block_start is `blockIdx.x - block_start`.
-    * @return Returns a `tuple` [Im, In] shifted index, used to shift 1d-tile index.
+    * @param [in] block_start Workgroup offset.
+    * @param [in] M Gemm's M dimension.
+    * @param [in] N Gemm's N dimension.
+    * @return Returns a `tuple` [Im, In] with shifted index.
     */
-   [[nodiscard]] CK_TILE_DEVICE static constexpr auto
-   GetOffsetedTileIndex(index_t block_start, index_t N) noexcept
+   [[nodiscard]] CK_TILE_DEVICE static auto
+   GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept
        -> const tuple<index_t, index_t>
    {
-       const auto [iM, iN] = PartitionerFn(N).GetOutputTileIndex(blockIdx.x - block_start);
+       const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
        return make_tuple(iM, iN);
    }
};
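The enable_if above keys off a detection trait; a minimal sketch of the idiom (the real HasFnOneArgImpl may differ in detail):

#include <type_traits>
#include <utility>

// True only when P exposes a GetOutputTileIndex callable with one index.
template <typename P, typename = void>
struct HasFnOneArg : std::false_type
{
};

template <typename P>
struct HasFnOneArg<P, std::void_t<decltype(std::declval<P>().GetOutputTileIndex(0))>>
    : std::true_type
{
};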
/**
 * @brief Class mapping 1D block index into 2D output tile space.
 *
 * @note It groups workgroups spatially in order to better utilize caches.
 * It uses a grouped rows-of-column-vectors WGP pattern, optimized
 * for gfx94x-like multiple-die chips.
 *
 * @tparam GroupNum - The number of big groups.
 * @tparam M01 - The number of groups in M dim within spatially local WGPs.
 */
template <typename BlockGemmShapeType, index_t GroupNum, index_t M01>
struct GemmSpatiallyLocalTilePartitioner
{
    using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner() noexcept = delete;
    CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner(index_t M_, index_t N_) noexcept
        : M(M_), N(N_)
    {
    }

    /**
     * @brief Calculates GEMM kernel grid size.
     *
     * @param M GEMM's M dimension.
     * @param N GEMM's N dimension.
     * @return index_t A total number of workgroups.
     */
    CK_TILE_HOST static auto GridSize(index_t M, index_t N)
        noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
    {
        const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
        const index_t GridDimY = integer_divide_ceil(N, NPerBlock);
        return GridDimX * GridDimY;
    }

    /**
     * @brief Calculate number of loop iterations over GEMM's K dimension.
     *
     * @param K GEMM's K dimension.
     * @return index_t The number of loop iterations over K dimension.
     */
    CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
    {
        return integer_divide_ceil(K, KPerBlock);
    }

    /**
     * @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
     *
     * @param [in] block_1d_id WGP's index.
     * @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
     */
    CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept
        -> const tuple<index_t, index_t>
    {
        const auto M0 = integer_divide_ceil(M, MPerBlock);
        const auto N0 = integer_divide_ceil(N, NPerBlock);

        if(M0 == 1)
        {
            return make_tuple(0, block_1d_id);
        }
        else if(N0 == 1)
        {
            return make_tuple(block_1d_id, 0);
        }
        // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
        else
        {
            const auto group_size    = integer_divide_ceil(M0 * N0, GroupNum);
            const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
            const auto group_id_y    = block_1d_id / GroupNum;
            const auto group_id_x    = block_1d_id - group_id_y * GroupNum;
            const auto remap_block_1d_id =
                group_id_x <= big_group_num
                    ? group_id_x * group_size + group_id_y
                    : group_id_x * group_size + big_group_num - group_id_x + group_id_y;

            const index_t idx_M0 = remap_block_1d_id / N0;
            const index_t idx_N0 = remap_block_1d_id - idx_M0 * N0;

            const index_t M0_tmp     = M0 / M01;
            const index_t M0_mod_M01 = M0 - M0_tmp * M01;
            const auto M01_adapt     = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;

            const index_t idx_M00          = idx_M0 / M01;
            const index_t idx_M01          = idx_M0 - idx_M00 * M01;
            const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

            /**
             * idxN0
             *
             * |< mtx N >|
             *
             * NPerBlock NPerBlock NPerBlock NPerBlock
             * N_0 N_1 N_2 N_3
             * - |-----------|-----------|-----------|-----|-----|-
             * ^ | - - 0 |/----> 2 | | | |
             * | | | / | | | | | M_0 MPerBlock
             * | M | /| | | | | |
             * |-0---|---/-|-----|-----|-----------|-----|-----|-
             * | 1 | / | | | blockid | | |
             * idxM0 | | | / | V | 5 | | | M_1 MPerBlock
             * | - V 1 | - 3 | | | |
             * |-----------|-----------|-----------|-----|-----|-
             * mtx M | | | | | |
             * | | | | | | M_2 MPerBlock
             * | | | | | |
             * |-----------|-----------|-----------|-----|-----|-
             * | | | | | |
             * | | | | | | M_3 MPerBlock
             * | | | | | |
             * |-----------|-----------|-----------|-----|-----|-
             * V | | | | | |
             * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
             * | | | | | |
             * |-----------|-----------|-----------|-----|-----|-
             * Example:
             * assume:
             * M0 = 5
             * N0 = 4
             * block_1d_id = 5
             * M01 = 2
             *
             * idx_N0 = 1
             * idx_M0 = 1
             * M01_adapt = 2
             * idx_M00 = 0
             * idx_M01 = 1
             * idx_N0_M01_local = 5
             * output {1, 2}
             */
            const index_t N_out           = idx_N0_M01_local / M01_adapt;
            const index_t idx_loc_mod_M01 = idx_N0_M01_local - N_out * M01_adapt;

            return make_tuple(idx_loc_mod_M01 + idx_M00 * M01, N_out);
        }
    }

    private:
    index_t M;
    index_t N;
};

} // namespace ck_tile
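A host-side sketch that replays GetOutputTileIndex's arithmetic with the values from the worked example in the comment above; GroupNum = M0 * N0 is chosen here so the group remap is the identity, an assumption made only to keep the sketch small:

#include <cassert>
#include <tuple>

std::tuple<int, int> remap(int block_1d_id, int M0, int N0, int M01, int GroupNum)
{
    const int group_size    = (M0 * N0 + GroupNum - 1) / GroupNum;
    const int big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
    const int group_id_y    = block_1d_id / GroupNum;
    const int group_id_x    = block_1d_id - group_id_y * GroupNum;
    const int remapped      = group_id_x <= big_group_num
                                  ? group_id_x * group_size + group_id_y
                                  : group_id_x * group_size + big_group_num - group_id_x + group_id_y;
    const int idx_M0     = remapped / N0;
    const int idx_N0     = remapped - idx_M0 * N0;
    const int M0_mod_M01 = M0 % M01;
    const int M01_adapt  = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
    const int idx_M00    = idx_M0 / M01;
    const int idx_M01    = idx_M0 - idx_M00 * M01;
    const int local      = idx_N0 + idx_M01 * N0;
    const int N_out      = local / M01_adapt;
    return {local - N_out * M01_adapt + idx_M00 * M01, N_out};
}

int main()
{
    // M0 = 5, N0 = 4, M01 = 2, block_1d_id = 5 -> output tile {1, 2}, as in the comment.
    assert((remap(5, 5, 4, 2, 20) == std::tuple<int, int>(1, 2)));
    return 0;
}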