Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8b49f207
Unverified
Commit
8b49f207
authored
Jan 07, 2025
by
Max Podkorytov
Committed by
GitHub
Jan 07, 2025
Browse files
Merge branch 'develop' into fa-h512
parents
0d59f474
a6b761c3
Changes
262
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
198 additions
and
181 deletions
+198
-181
include/ck/utility/amd_ck_fp8.hpp
include/ck/utility/amd_ck_fp8.hpp
+18
-6
include/ck/utility/amd_inline_asm.hpp
include/ck/utility/amd_inline_asm.hpp
+22
-1
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+37
-0
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+4
-2
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+1
-1
include/ck/utility/static_buffer.hpp
include/ck/utility/static_buffer.hpp
+4
-2
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+14
-1
include/ck_tile/README.md
include/ck_tile/README.md
+3
-0
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+2
-1
include/ck_tile/core/arch/amd_buffer_addressing.hpp
include/ck_tile/core/arch/amd_buffer_addressing.hpp
+2
-2
include/ck_tile/core/container/meta_data_buffer.hpp
include/ck_tile/core/container/meta_data_buffer.hpp
+3
-3
include/ck_tile/core/tensor/static_distributed_tensor.hpp
include/ck_tile/core/tensor/static_distributed_tensor.hpp
+1
-0
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+1
-1
include/ck_tile/host/arg_parser.hpp
include/ck_tile/host/arg_parser.hpp
+44
-2
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+11
-151
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
+1
-1
include/ck_tile/ops/common.hpp
include/ck_tile/ops/common.hpp
+1
-1
include/ck_tile/ops/elementwise.hpp
include/ck_tile/ops/elementwise.hpp
+1
-1
include/ck_tile/ops/epilogue.hpp
include/ck_tile/ops/epilogue.hpp
+1
-1
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+27
-4
No files found.
include/ck/utility/amd_ck_fp8.hpp
View file @
8b49f207
...
@@ -18,6 +18,20 @@
...
@@ -18,6 +18,20 @@
#define CK_USE_OCP_FP8 0
#define CK_USE_OCP_FP8 0
#endif
#endif
namespace
{
// https://en.cppreference.com/w/cpp/types/conditional
template
<
bool
B
,
class
T
,
class
F
>
struct
conditional
{
using
type
=
T
;
};
template
<
class
T
,
class
F
>
struct
conditional
<
false
,
T
,
F
>
{
using
type
=
F
;
};
}
// namespace
namespace
ck
{
namespace
ck
{
using
f8_fnuz_t
=
_BitInt
(
8
);
using
f8_fnuz_t
=
_BitInt
(
8
);
...
@@ -191,11 +205,10 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
...
@@ -191,11 +205,10 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
}
}
}
}
typename
__hip_internal
::
conditional
<
typename
conditional
<
sizeof
(
T
)
==
2
,
sizeof
(
T
)
==
2
,
unsigned
short
int
,
unsigned
short
int
,
typename
__hip_internal
::
conditional
<
sizeof
(
T
)
==
4
,
unsigned
int
,
unsigned
long
long
>::
typename
conditional
<
sizeof
(
T
)
==
4
,
unsigned
int
,
unsigned
long
long
>::
type
>::
type
retval
;
type
>::
type
retval
;
if
constexpr
(
we
==
5
&&
is_half
&&
!
is_fnuz
)
if
constexpr
(
we
==
5
&&
is_half
&&
!
is_fnuz
)
{
{
...
@@ -538,11 +551,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
...
@@ -538,11 +551,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
constexpr
int
mfmt
=
(
sizeof
(
T
)
==
8
)
?
52
:
((
sizeof
(
T
)
==
4
)
?
23
:
10
);
constexpr
int
mfmt
=
(
sizeof
(
T
)
==
8
)
?
52
:
((
sizeof
(
T
)
==
4
)
?
23
:
10
);
using
T_bitwise
=
typename
__hip_internal
::
conditional
<
using
T_bitwise
=
typename
conditional
<
sizeof
(
T
)
==
2
,
sizeof
(
T
)
==
2
,
unsigned
short
int
,
unsigned
short
int
,
typename
__hip_internal
::
conditional
<
sizeof
(
T
)
==
4
,
unsigned
int
,
unsigned
long
long
>::
typename
conditional
<
sizeof
(
T
)
==
4
,
unsigned
int
,
unsigned
long
long
>::
type
>::
type
;
type
>::
type
;
T_bitwise
x_bitwise
=
bit_cast
<
T_bitwise
>
(
_x
);
T_bitwise
x_bitwise
=
bit_cast
<
T_bitwise
>
(
_x
);
unsigned
long
long
x
{
x_bitwise
};
unsigned
long
long
x
{
x_bitwise
};
...
...
include/ck/utility/amd_inline_asm.hpp
View file @
8b49f207
...
@@ -4,13 +4,34 @@
...
@@ -4,13 +4,34 @@
#ifndef CK_AMD_INLINE_ASM_HPP
#ifndef CK_AMD_INLINE_ASM_HPP
#define CK_AMD_INLINE_ASM_HPP
#define CK_AMD_INLINE_ASM_HPP
#include "data_type.hpp"
#include "c_style_pointer_cast.hpp"
#include "c_style_pointer_cast.hpp"
#include "data_type.hpp"
// TODO: deprecate all amd_assembly_outer_product_xxx
// TODO: deprecate all amd_assembly_outer_product_xxx
namespace
ck
{
namespace
ck
{
inline
__device__
int
amd_assembly_and_or_b32
(
int
a
,
int
b
,
int
d
)
{
int
c
;
asm
volatile
(
"v_and_or_b32 %0, %1, %2, %3"
:
"=v"
(
c
)
:
"v"
(
a
),
"v"
(
b
),
"v"
(
d
));
return
c
;
}
inline
__device__
half2_t
amd_assembly_pk_fma_f16
(
half2_t
a
,
half2_t
b
,
half2_t
c
)
{
half2_t
d
;
asm
volatile
(
"v_pk_fma_f16 %0, %1, %2, %3"
:
"=v"
(
d
)
:
"v"
(
a
),
"v"
(
b
),
"v"
(
c
));
return
d
;
}
inline
__device__
half2_t
amd_assembly_pk_add_f16
(
half2_t
a
,
half2_t
b
)
{
half2_t
c
;
asm
volatile
(
"v_pk_add_f16 %0, %1, %2"
:
"=v"
(
c
)
:
"v"
(
a
),
"v"
(
b
));
return
c
;
}
// c0 += inner_product(a, b0)
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c1 += inner_product(a, b1)
__device__
void
amd_assembly_outer_product_1x2
(
float
a
,
float
b0
,
float
b1
,
float
&
c0
,
float
&
c1
)
__device__
void
amd_assembly_outer_product_1x2
(
float
a
,
float
b0
,
float
b1
,
float
&
c0
,
float
&
c1
)
...
...
include/ck/utility/data_type.hpp
View file @
8b49f207
...
@@ -12,6 +12,17 @@ using bhalf_t = ushort;
...
@@ -12,6 +12,17 @@ using bhalf_t = ushort;
using
half_t
=
_Float16
;
using
half_t
=
_Float16
;
using
int4_t
=
_BitInt
(
4
);
using
int4_t
=
_BitInt
(
4
);
// custom data type - pack int4 data
struct
pk_i4_t
{
using
type
=
int8_t
;
type
data
;
__host__
__device__
constexpr
pk_i4_t
()
:
data
{
type
{}}
{}
__host__
__device__
constexpr
pk_i4_t
(
type
init
)
:
data
{
init
}
{}
__host__
__device__
constexpr
operator
float
()
const
{
return
static_cast
<
int8_t
>
(
data
);
}
};
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
{
{
// Precondition: x > 1.
// Precondition: x > 1.
...
@@ -165,6 +176,13 @@ struct scalar_type<int4_t>
...
@@ -165,6 +176,13 @@ struct scalar_type<int4_t>
};
};
#endif
#endif
template
<
>
struct
scalar_type
<
pk_i4_t
>
{
using
type
=
pk_i4_t
;
static
constexpr
index_t
vector_size
=
1
;
};
template
<
>
template
<
>
struct
scalar_type
<
f8_fnuz_t
>
struct
scalar_type
<
f8_fnuz_t
>
{
{
...
@@ -1044,6 +1062,12 @@ struct nnvb_data_t_selector<bf8_ocp_t>
...
@@ -1044,6 +1062,12 @@ struct nnvb_data_t_selector<bf8_ocp_t>
using
type
=
bf8_ocp_t
::
data_type
;
using
type
=
bf8_ocp_t
::
data_type
;
};
};
template
<
>
struct
nnvb_data_t_selector
<
pk_i4_t
>
{
using
type
=
pk_i4_t
::
type
;
};
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_base
<
struct
non_native_vector_base
<
T
,
T
,
...
@@ -1163,6 +1187,14 @@ struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
...
@@ -1163,6 +1187,14 @@ struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
static
constexpr
index_t
vector_size
=
N
;
static
constexpr
index_t
vector_size
=
N
;
};
};
template
<
index_t
N
>
struct
scalar_type
<
non_native_vector_base
<
pk_i4_t
,
N
>>
{
using
type
=
typename
non_native_vector_base
<
pk_i4_t
,
N
>::
data_t
;
static
constexpr
index_t
vector_size
=
N
;
};
// non-native vector_type implementation
// non-native vector_type implementation
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
...
@@ -1871,6 +1903,11 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
...
@@ -1871,6 +1903,11 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using
uint8x32_t
=
typename
vector_type
<
uint8_t
,
32
>::
type
;
using
uint8x32_t
=
typename
vector_type
<
uint8_t
,
32
>::
type
;
using
uint8x64_t
=
typename
vector_type
<
uint8_t
,
64
>::
type
;
using
uint8x64_t
=
typename
vector_type
<
uint8_t
,
64
>::
type
;
// pack int4
using
pk_i4x2_t
=
typename
vector_type
<
pk_i4_t
,
2
>::
type
;
using
pk_i4x4_t
=
typename
vector_type
<
pk_i4_t
,
4
>::
type
;
using
pk_i4x8_t
=
typename
vector_type
<
pk_i4_t
,
8
>::
type
;
template
<
typename
T
>
template
<
typename
T
>
struct
NumericLimits
struct
NumericLimits
{
{
...
...
include/ck/utility/dynamic_buffer.hpp
View file @
8b49f207
...
@@ -54,7 +54,8 @@ struct DynamicBuffer
...
@@ -54,7 +54,8 @@ struct DynamicBuffer
template
<
typename
X
,
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
||
!
is_native_type
<
X
>
(),
bool
>::
type
=
false
>
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
Get
(
index_t
i
,
bool
is_valid_element
)
const
__host__
__device__
constexpr
auto
Get
(
index_t
i
,
bool
is_valid_element
)
const
{
{
...
@@ -195,7 +196,8 @@ struct DynamicBuffer
...
@@ -195,7 +196,8 @@ struct DynamicBuffer
template
<
typename
X
,
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
||
!
is_native_type
<
X
>
(),
bool
>::
type
=
false
>
bool
>::
type
=
false
>
__host__
__device__
void
Set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
__host__
__device__
void
Set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
{
{
...
...
include/ck/utility/math_v2.hpp
View file @
8b49f207
...
@@ -611,7 +611,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x)
...
@@ -611,7 +611,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x)
template
<
>
template
<
>
inline
__device__
half_t
neg
<
half_t
>
(
half_t
x
)
inline
__device__
half_t
neg
<
half_t
>
(
half_t
x
)
{
{
return
__hneg
(
x
);
return
__hneg
(
static_cast
<
__half
>
(
x
)
);
};
};
template
<
typename
T
>
template
<
typename
T
>
...
...
include/ck/utility/static_buffer.hpp
View file @
8b49f207
...
@@ -116,7 +116,8 @@ struct StaticBufferTupleOfVector
...
@@ -116,7 +116,8 @@ struct StaticBufferTupleOfVector
// i is offset of S, not X. i should be aligned to X
// i is offset of S, not X. i should be aligned to X
template
<
typename
X
,
template
<
typename
X
,
index_t
I
,
index_t
I
,
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
,
bool
>::
type
=
false
>
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
||
!
is_native_type
<
S
>
(),
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
GetAsType
(
Number
<
I
>
i
)
const
__host__
__device__
constexpr
auto
GetAsType
(
Number
<
I
>
i
)
const
{
{
constexpr
auto
s_per_x
=
Number
<
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
>
{};
constexpr
auto
s_per_x
=
Number
<
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
>
{};
...
@@ -134,7 +135,8 @@ struct StaticBufferTupleOfVector
...
@@ -134,7 +135,8 @@ struct StaticBufferTupleOfVector
// i is offset of S, not X. i should be aligned to X
// i is offset of S, not X. i should be aligned to X
template
<
typename
X
,
template
<
typename
X
,
index_t
I
,
index_t
I
,
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
,
bool
>::
type
=
false
>
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
||
!
is_native_type
<
S
>
(),
bool
>::
type
=
false
>
__host__
__device__
constexpr
void
SetAsType
(
Number
<
I
>
i
,
X
x
)
__host__
__device__
constexpr
void
SetAsType
(
Number
<
I
>
i
,
X
x
)
{
{
constexpr
auto
s_per_x
=
Number
<
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
>
{};
constexpr
auto
s_per_x
=
Number
<
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
>
{};
...
...
include/ck/utility/type_convert.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -465,6 +465,19 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
...
@@ -465,6 +465,19 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
#endif
#endif
}
}
template
<
>
inline
__host__
__device__
float2_t
type_convert
<
float2_t
,
pk_i4_t
>
(
pk_i4_t
x
)
{
uint8_t
x_u8
=
ck
::
bit_cast
<
uint8_t
>
(
x
);
uint8_t
x_l
=
(
x_u8
&
0x0f
)
>>
0
;
uint8_t
x_h
=
(
x_u8
&
0xf0
)
>>
4
;
auto
l_f32
=
ck
::
type_convert
<
float
>
(
x_l
);
auto
h_f32
=
ck
::
type_convert
<
float
>
(
x_h
);
return
{
l_f32
,
h_f32
};
}
template
<
>
template
<
>
inline
__host__
__device__
half2_t
type_convert
<
half2_t
,
float2_t
>
(
float2_t
x
)
inline
__host__
__device__
half2_t
type_convert
<
half2_t
,
float2_t
>
(
float2_t
x
)
{
{
...
...
include/ck_tile/README.md
View file @
8b49f207
...
@@ -45,5 +45,8 @@ our implementation of different device operators.
...
@@ -45,5 +45,8 @@ our implementation of different device operators.
**[ops/epilogue]**
**[ops/epilogue]**
epilogue part of our kernel. We may extend this epilogue part to let users to build their own cutomized epilogues.
epilogue part of our kernel. We may extend this epilogue part to let users to build their own cutomized epilogues.
**[ref]**
reference implementation of cpu or gpu. This folder is supposed to include a specific header on demand.
## examples
## examples
currently we put all ck_tile related example under
[
/example/ck_tile
](
/example/ck_tile/
)
folder. Please check each example's subfolder.
currently we put all ck_tile related example under
[
/example/ck_tile
](
/example/ck_tile/
)
folder. Please check each example's subfolder.
include/ck_tile/core.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -54,6 +54,7 @@
...
@@ -54,6 +54,7 @@
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
...
...
include/ck_tile/core/arch/amd_buffer_addressing.hpp
View file @
8b49f207
...
@@ -1303,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
...
@@ -1303,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
static_assert
(
static_assert
(
(
std
::
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
bf16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
bf16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
int32_t
>::
value
&&
(
std
::
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
...
...
include/ck_tile/core/container/meta_data_buffer.hpp
View file @
8b49f207
...
@@ -30,7 +30,7 @@ struct meta_data_buffer
...
@@ -30,7 +30,7 @@ struct meta_data_buffer
{
{
constexpr
index_t
size
=
sizeof
(
T
);
constexpr
index_t
size
=
sizeof
(
T
);
auto
tmp
=
bit_cast
<
array
<
std
::
byte
,
size
>>
(
data
);
auto
tmp
=
ck_tile
::
bit_cast
<
array
<
std
::
byte
,
size
>>
(
data
);
for
(
int
i
=
0
;
i
<
size
;
i
++
)
for
(
int
i
=
0
;
i
<
size
;
i
++
)
{
{
...
@@ -66,7 +66,7 @@ struct meta_data_buffer
...
@@ -66,7 +66,7 @@ struct meta_data_buffer
pos
++
;
pos
++
;
}
}
data
=
bit_cast
<
T
>
(
tmp
);
data
=
ck_tile
::
bit_cast
<
T
>
(
tmp
);
}
}
return
data
;
return
data
;
...
@@ -86,7 +86,7 @@ struct meta_data_buffer
...
@@ -86,7 +86,7 @@ struct meta_data_buffer
pos
++
;
pos
++
;
}
}
auto
data
=
bit_cast
<
T
>
(
tmp
);
auto
data
=
ck_tile
::
bit_cast
<
T
>
(
tmp
);
return
data
;
return
data
;
}
}
...
...
include/ck_tile/core/tensor/static_distributed_tensor.hpp
View file @
8b49f207
...
@@ -29,6 +29,7 @@ struct static_distributed_tensor
...
@@ -29,6 +29,7 @@ struct static_distributed_tensor
remove_cvref_t
<
decltype
(
StaticTileDistribution
{}.
get_ys_to_d_descriptor
())
>
;
remove_cvref_t
<
decltype
(
StaticTileDistribution
{}.
get_ys_to_d_descriptor
())
>
;
static
constexpr
index_t
kThreadElementSpaceSize
=
ThreadTensorDesc
{}.
get_element_space_size
();
static
constexpr
index_t
kThreadElementSpaceSize
=
ThreadTensorDesc
{}.
get_element_space_size
();
static_assert
(
0
<
kThreadElementSpaceSize
,
"Make sure tile distribution is valid"
);
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_dimension
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_dimension
()
{
{
...
...
include/ck_tile/host.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/host/arg_parser.hpp
View file @
8b49f207
...
@@ -15,11 +15,14 @@
...
@@ -15,11 +15,14 @@
namespace
ck_tile
{
namespace
ck_tile
{
/*
/*
* a host side utility, arg parser for
* a host side utility, arg parser for, either
* -[key0]=[value0] -[key1]=[value1] ...
* -[key0] = [value0, value1, value2]
* or
* -[key0]=[value0] -[key1]=[value1] ...
*/
*/
class
ArgParser
class
ArgParser
{
{
public:
public:
class
Arg
class
Arg
{
{
...
@@ -187,6 +190,45 @@ class ArgParser
...
@@ -187,6 +190,45 @@ class ArgParser
return
value
;
return
value
;
}
}
std
::
vector
<
std
::
string
>
get_string_vec
(
const
std
::
string
&
name
,
const
std
::
string
&
delimiter
=
","
)
const
{
if
(
get_str
(
name
).
empty
())
{
return
{};
}
std
::
string
s
=
get_str
(
name
);
std
::
vector
<
std
::
string
>
tokens
;
size_t
pos
=
0
;
std
::
string
token
;
while
((
pos
=
s
.
find
(
delimiter
))
!=
std
::
string
::
npos
)
{
token
=
s
.
substr
(
0
,
pos
);
tokens
.
push_back
(
token
);
s
.
erase
(
0
,
pos
+
delimiter
.
length
());
}
tokens
.
push_back
(
s
);
return
tokens
;
}
std
::
vector
<
int
>
get_int_vec
(
const
std
::
string
&
name
,
const
std
::
string
&
delimiter
=
","
)
const
{
if
(
get_str
(
name
).
empty
())
{
return
{};
}
const
std
::
vector
<
std
::
string
>
args
=
get_string_vec
(
name
,
delimiter
);
std
::
vector
<
int
>
tokens
;
tokens
.
reserve
(
static_cast
<
int
>
(
args
.
size
()));
for
(
const
std
::
string
&
token
:
args
)
{
int
value
=
atoi
(
token
.
c_str
());
tokens
.
push_back
(
value
);
}
return
tokens
;
}
private:
private:
std
::
unordered_map
<
std
::
string
,
Arg
>
input_map
;
std
::
unordered_map
<
std
::
string
,
Arg
>
input_map
;
std
::
vector
<
std
::
string
>
keys
;
std
::
vector
<
std
::
string
>
keys
;
...
...
include/ck_tile/host/reference/reference_gemm.hpp
View file @
8b49f207
...
@@ -97,9 +97,9 @@ template <typename ADataType,
...
@@ -97,9 +97,9 @@ template <typename ADataType,
typename
LayoutA
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutB
,
typename
LayoutC
>
typename
LayoutC
>
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
void
reference_gemm_gpu
(
ADataType
*
a_ptr
,
DeviceMem
&
b_device
,
BDataType
*
b_ptr
,
DeviceMem
&
c_device
,
CDataType
*
c_ptr
,
index_t
M
,
index_t
M
,
index_t
N
,
index_t
N
,
index_t
K
,
index_t
K
,
...
@@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device,
...
@@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device,
index_t
stride_b
,
index_t
stride_b
,
index_t
stride_c
)
index_t
stride_c
)
{
{
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
hipError_t
errA
=
hipMalloc
(
&
d_A
,
M
*
K
*
sizeof
(
ADataType
));
hipError_t
errB
=
hipMalloc
(
&
d_B
,
N
*
K
*
sizeof
(
BDataType
));
hipError_t
errC
=
hipMalloc
(
&
d_C
,
M
*
N
*
sizeof
(
CDataType
));
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for A: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for B: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for C: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
return
;
// Early exit on error
}
errA
=
hipMemcpy
(
d_A
,
a_device
.
GetDeviceBuffer
(),
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying A to device: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipMemcpy
(
d_B
,
b_device
.
GetDeviceBuffer
(),
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying B to device: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
int
totalElements
=
M
*
N
;
int
totalElements
=
M
*
N
;
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
errC
=
hipMemcpy
(
a_ptr
,
b_ptr
,
c_ptr
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
c_device
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying C to device: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
errA
=
hipFree
(
d_A
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the A memory: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipFree
(
d_B
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the B memory: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
errC
=
hipFree
(
d_C
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the C memory: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
return
;
return
;
}
}
...
@@ -191,9 +125,9 @@ template <typename ADataType,
...
@@ -191,9 +125,9 @@ template <typename ADataType,
typename
LayoutA
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutB
,
typename
LayoutC
>
typename
LayoutC
>
void
reference_batched_gemm_gpu
(
DeviceMem
&
a_device
,
void
reference_batched_gemm_gpu
(
ADataType
*
a_ptr
,
DeviceMem
&
b_device
,
BDataType
*
b_ptr
,
DeviceMem
&
c_device
,
CDataType
*
c_ptr
,
index_t
M
,
index_t
M
,
index_t
N
,
index_t
N
,
index_t
K
,
index_t
K
,
...
@@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device,
...
@@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device,
index_t
batch_stride_C
,
index_t
batch_stride_C
,
index_t
batch_count
)
index_t
batch_count
)
{
{
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
hipError_t
errA
=
hipMalloc
(
&
d_A
,
batch_count
*
M
*
K
*
sizeof
(
ADataType
));
hipError_t
errB
=
hipMalloc
(
&
d_B
,
batch_count
*
N
*
K
*
sizeof
(
BDataType
));
hipError_t
errC
=
hipMalloc
(
&
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
));
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for A: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for B: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for C: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
return
;
// Early exit on error
}
errA
=
hipMemcpy
(
d_A
,
a_device
.
GetDeviceBuffer
(),
batch_count
*
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying A to device: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipMemcpy
(
d_B
,
b_device
.
GetDeviceBuffer
(),
batch_count
*
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying B to device: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
int
totalElements
=
M
*
N
;
int
totalElements
=
M
*
N
;
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
for
(
index_t
batch_id
=
0
;
batch_id
<
batch_count
;
++
batch_id
)
for
(
index_t
batch_id
=
0
;
batch_id
<
batch_count
;
++
batch_id
)
{
{
ADataType
*
d_ATemp
=
d_A
+
batch_id
*
batch_stride_A
;
ADataType
*
d_ATemp
=
a_ptr
+
batch_id
*
batch_stride_A
;
BDataType
*
d_BTemp
=
d_B
+
batch_id
*
batch_stride_B
;
BDataType
*
d_BTemp
=
b_ptr
+
batch_id
*
batch_stride_B
;
CDataType
*
d_CTemp
=
d_C
+
batch_id
*
batch_stride_C
;
CDataType
*
d_CTemp
=
c_ptr
+
batch_id
*
batch_stride_C
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_ATemp
,
d_BTemp
,
d_CTemp
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
d_ATemp
,
d_BTemp
,
d_CTemp
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
}
}
errC
=
hipMemcpy
(
c_device
.
GetDeviceBuffer
(),
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying C to device: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
errA
=
hipFree
(
d_A
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the A memory: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipFree
(
d_B
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the B memory: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
errC
=
hipFree
(
d_C
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the C memory: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
return
;
return
;
}
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/common.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/elementwise.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/epilogue.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
View file @
8b49f207
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -56,6 +56,13 @@ struct CShuffleEpilogue
...
@@ -56,6 +56,13 @@ struct CShuffleEpilogue
// No additional shared memory needed
// No additional shared memory needed
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
IsOutputTransposed
()
{
// TODO: At now CShuffle doesn't allow to vector store after permute.
// It should be fixed and this function should return true.
return
false
;
}
template
<
typename
OAccTile
>
template
<
typename
OAccTile
>
CK_TILE_DEVICE
void
permute_tile_data
(
OAccTile
&
o_acc_tile
)
CK_TILE_DEVICE
void
permute_tile_data
(
OAccTile
&
o_acc_tile
)
{
{
...
@@ -111,7 +118,9 @@ struct CShuffleEpilogue
...
@@ -111,7 +118,9 @@ struct CShuffleEpilogue
}
}
}
}
template
<
typename
ODramWindowTmp
,
typename
OAccTile
>
template
<
typename
ODramWindowTmp
,
typename
OAccTile
,
memory_operation_enum
out_memory_data_op
=
memory_operation_enum
::
set
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
OAccTile
&
o_acc_tile
)
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
OAccTile
&
o_acc_tile
)
{
{
const
auto
&
current_window_origin
=
o_dram_window_tmp
.
get_window_origin
();
const
auto
&
current_window_origin
=
o_dram_window_tmp
.
get_window_origin
();
...
@@ -158,12 +167,26 @@ struct CShuffleEpilogue
...
@@ -158,12 +167,26 @@ struct CShuffleEpilogue
// Store the tile data to the permuted location
// Store the tile data to the permuted location
if
constexpr
(
kPadM
||
kPadN
)
if
constexpr
(
kPadM
||
kPadN
)
{
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
else
{
update_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
buffer_store_fence
();
buffer_store_fence
();
}
}
else
else
{
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
else
{
update_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
}
}
}
}
};
};
...
...
Prev
1
2
3
4
5
6
7
8
9
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment