gaoqiong / composable_kernel_ROCM · Commits

Commit 739d3db9, authored Oct 18, 2024 by Andriy Roshchenko

Enable OCP build of example_gemm_xdl_fp8.

Parent: f1fe1ce6

Showing 8 changed files with 171 additions and 28 deletions (+171, -28).
example/01_gemm/run_gemm_example.inc  (+2, -2)
include/ck/utility/amd_buffer_addressing.hpp  (+4, -2)
include/ck/utility/amd_ck_fp8.hpp  (+85, -2)
include/ck/utility/data_type.hpp  (+16, -2)
include/ck/utility/type_convert.hpp  (+44, -0)
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp  (+4, -4)
library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp  (+14, -14)
library/include/ck/library/utility/host_tensor_generator.hpp  (+2, -2)
example/01_gemm/run_gemm_example.inc

@@ -143,8 +143,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     switch(config.init_method)
     {
     case 0:
-        ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
-        ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
+        ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.f)}(a_m_k);
+        ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.f)}(b_k_n);
         break;
     case 1:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
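The swap from static_cast to ck::type_convert is what lets this example build for OCP FP8: f8_ocp_t is a class type wrapping a storage byte, so there is no built-in conversion from float, while ck::type_convert gains explicit specializations for it in type_convert.hpp below. A minimal standalone sketch of the idea — Wrapped and convert are hypothetical stand-ins, not CK names:

// Sketch only: 'Wrapped' stands in for a CK FP8 class type such as f8_ocp_t.
#include <cstdint>

struct Wrapped
{
    std::uint8_t data; // opaque encoding; no implicit conversion from float
};

// A convert<> function template specialized per target type, which is the
// role ck::type_convert plays for f8_ocp_t / bf8_ocp_t.
template <typename Y, typename X>
Y convert(X x);

template <>
Wrapped convert<Wrapped, float>(float x)
{
    return Wrapped{static_cast<std::uint8_t>(x)}; // placeholder for real FP8 rounding
}

int main()
{
    Wrapped one = convert<Wrapped>(1.f); // compiles; static_cast<Wrapped>(1.f) would not
    (void)one;
}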
include/ck/utility/amd_buffer_addressing.hpp

@@ -549,8 +549,10 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
     static_assert(
         (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
             (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
             (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
             (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
             (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
             (is_same<T, f8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
             (is_same<T, bf8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
             (is_same<T, fp8_storage_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
             (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
         "wrong! not implemented");
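This static_assert is the type whitelist for vectorized buffer stores; the hunk widens it (+4, -2) so the FP8 variants used by the OCP build pass the check. The repetitive chain reads as "T is one of these types and N is a supported width"; a compact standalone sketch of the same guard, using illustrative names (is_any_of_v, store_checked) rather than CK's:

#include <cstdint>
#include <type_traits>

// Hypothetical helper: true when T matches any listed type (C++17 fold).
template <typename T, typename... Ts>
inline constexpr bool is_any_of_v = (std::is_same_v<T, Ts> || ...);

template <typename T, int N>
void store_checked()
{
    static_assert(is_any_of_v<T, float, double, std::int32_t, std::int8_t, unsigned char> &&
                      (N == 1 || N == 2 || N == 4 || N == 8 || N == 16),
                  "wrong! not implemented");
    // ... issue the buffer store here ...
}

int main() { store_checked<float, 4>(); } // compiles; store_checked<long, 3>() would not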
include/ck/utility/amd_ck_fp8.hpp

@@ -28,6 +28,12 @@ using bf8_fnuz_t = unsigned _BitInt(8);
 #define CK_FP8_CVT_FAST_PATH 0
 #endif
 
+#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
+#define CK_OFP8_CVT_FAST_PATH 1
+#else
+#define CK_OFP8_CVT_FAST_PATH 0
+#endif
+
 typedef unsigned char fp8_storage_t;
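CK_OFP8_CVT_FAST_PATH mirrors the existing CK_FP8_CVT_FAST_PATH switch: it is 1 only when device code is being compiled for gfx950/gfx1200/gfx1201, the targets with native OCP-FP8 conversion instructions, and 0 on the host pass and other GPUs, where conversion falls back to software decoding. For intuition, here is a minimal standalone decoder for OCP e4m3 (1 sign, 4 exponent bits with bias 7, 3 mantissa bits) showing the arithmetic such a fallback performs; it is a sketch, not CK's cast_from_f8:

#include <cmath>
#include <cstdio>
#include <limits>

static float fp8_e4m3_to_float(unsigned char b)
{
    const int sign = b >> 7;
    const int exp  = (b >> 3) & 0xF;
    const int man  = b & 0x7;
    float v;
    if(exp == 0xF && man == 0x7)
        v = std::numeric_limits<float>::quiet_NaN(); // only S.1111.111 is NaN in OCP e4m3
    else if(exp == 0)
        v = std::ldexp(static_cast<float>(man), -9); // subnormal: man * 2^-9
    else
        v = std::ldexp(static_cast<float>(8 + man), exp - 10); // (1 + man/8) * 2^(exp-7)
    return sign ? -v : v;
}

int main()
{
    std::printf("0x38 -> %g\n", fp8_e4m3_to_float(0x38)); // expect 1
    std::printf("0x7E -> %g\n", fp8_e4m3_to_float(0x7E)); // expect 448 (max normal)
}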
@@ -52,6 +58,9 @@ enum ck_saturation_t
 namespace fp8_impl {
 
+typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2)));
+typedef float float2_t __attribute__((ext_vector_type(2)));
+
 __host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a)
 {
     return static_cast<unsigned char>(a) == 0x80;
@@ -250,6 +259,33 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v)
         return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
     }
 }
 
+template <ck_fp8_interpretation_t interpret>
+static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
+{
+    // union
+    // {
+    //     unsigned int i32val;
+    //     unsigned short i16val[2];
+    // } val;
+    // val.i16val[0] = v;
+    const auto i16val = bit_cast<uint16_t>(v);
+
+    static_assert(interpret == CK_E4M3_FNUZ || interpret == CK_E4M3_OCP ||
+                      interpret == CK_E5M2_FNUZ || interpret == CK_E5M2_OCP,
+                  "Only FNUZ and OCP interpretations are supported");
+
+    if constexpr((interpret == CK_E4M3_FNUZ) || (interpret == CK_E4M3_OCP))
+    {
+        return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false);
+    }
+    else
+    {
+        return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
+    }
+}
+
 #endif
 } // namespace fp8_impl
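The new cast_to_f32x2_from_f8x2 converts two packed FP8 values with a single __builtin_amdgcn_cvt_pk_f32_fp8 / _bf8 instruction; the bit_cast of the 2-byte vector into a uint16_t replaces the commented-out union trick. That reinterpretation can be tried on any host with std::bit_cast standing in for ck::bit_cast (the printed value in the comment assumes a little-endian host):

#include <bit>       // std::bit_cast, C++20
#include <cstdint>
#include <cstdio>

int main()
{
    unsigned char pair[2] = {0x38, 0x40};          // two fp8 bytes, lane 0 and lane 1
    auto i16 = std::bit_cast<std::uint16_t>(pair); // same bytes, one 16-bit payload
    std::printf("0x%04x\n", i16);                  // 0x4038 on a little-endian host
}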
@@ -276,7 +312,7 @@ struct f8_ocp_t
     __host__ explicit operator float() const
 #endif
     {
-#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
+#if CK_OFP8_CVT_FAST_PATH
         return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
 #else
         return fp8_impl::cast_from_f8<float, wm, we, false>(

@@ -290,7 +326,7 @@ struct f8_ocp_t
     __host__ explicit operator _Float16() const
 #endif
     {
-#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
+#if CK_OFP8_CVT_FAST_PATH
         return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
 #else
         return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
@@ -299,6 +335,53 @@ struct f8_ocp_t
     }
 };
 
+template <typename T, index_t N>
+struct non_native_vector_base;
+
+template <index_t N>
+struct non_native_vector_base<f8_ocp_t, N>
+{
+    using data_t = f8_ocp_t::data_type;
+    using data_v = data_t __attribute__((ext_vector_type(sizeof(data_t) * N)));
+    using type   = non_native_vector_base<f8_ocp_t, N>;
+
+    data_v d; // storage vector
+
+    __host__ __device__ non_native_vector_base() = default;
+    __host__ __device__ non_native_vector_base(data_t a) : d{a} {}
+    __host__ __device__ non_native_vector_base(data_v v) : d{v} {}
+
+    __host__ __device__ operator data_v() const { return d; }
+};
+
+template <>
+struct non_native_vector_base<f8_ocp_t, 2>
+{
+    using data_t = f8_ocp_t::data_type;
+    using type   = non_native_vector_base<f8_ocp_t, 2>;
+
+    __host__ __device__ non_native_vector_base() = default;
+
+    using data_v = fp8_impl::fp8x2_storage_t; // type of storage vector
+    data_v d;                                 // storage vector
+    using float2_t = fp8_impl::float2_t;
+
+#if CK_USE_OCP_FP8
+    __host__ __device__ explicit operator float2_t() const
+#else
+    __host__ explicit operator float2_t() const
+#endif
+    {
+#if CK_OFP8_CVT_FAST_PATH
+        return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(d);
+#else
+        return float2_t{fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(d[0]),
+                        fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(d[1])};
+#endif
+    }
+};
+
 struct bf8_ocp_t
 {
     using data_type = fp8_storage_t;
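These specializations give f8_ocp_t, which has no native vector type, byte-backed vector storage plus an explicit unpacking operator; on fast-path targets the N == 2 case becomes a single packed conversion. The wrapper pattern itself compiles with any clang-based compiler (hipcc included); this sketch uses hypothetical names and a placeholder decode in place of CK's cast helpers:

#include <cstdio>

typedef unsigned char byte2_t __attribute__((ext_vector_type(2))); // like fp8x2_storage_t
typedef float float2_t __attribute__((ext_vector_type(2)));

struct PackedPair // stands in for non_native_vector_base<f8_ocp_t, 2>
{
    byte2_t d; // storage vector

    explicit operator float2_t() const
    {
        // placeholder element decode; the real code picks the packed hardware
        // conversion or per-element cast_from_f8, as shown above
        return float2_t{static_cast<float>(d[0]), static_cast<float>(d[1])};
    }
};

int main()
{
    PackedPair p{{1, 2}};
    float2_t f = static_cast<float2_t>(p);
    std::printf("%g %g\n", f[0], f[1]); // 1 2
}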
include/ck/utility/data_type.hpp

@@ -1031,8 +1031,22 @@ struct non_native_vector_base
     __host__ __device__ non_native_vector_base() = default;
 
-    typedef char data_v __attribute__((ext_vector_type(sizeof(T) * N)));
-    data_v d;
+    T d[N];
 };
 
+template <typename T, index_t N>
+struct scalar_type<non_native_vector_base<T, N>>;
+// {
+//     using type = T;
+//     static constexpr index_t vector_size = N;
+// };
+
+template <index_t N>
+struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
+{
+    using type = typename non_native_vector_base<f8_ocp_t, N>::data_t;
+    static constexpr index_t vector_size = N;
+};
+
 // non-native vector_type implementation
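Declaring scalar_type<non_native_vector_base<T, N>> without a body (the commented-out definition is deliberately dead) makes any unsupported instantiation fail at compile time, while the f8_ocp_t specialization opts that one type in. A standalone sketch of this opt-in trait pattern, with illustrative names (Vec, scalar_info):

#include <cstdio>

template <typename T, int N> struct Vec { T d[N]; };

template <typename V> struct scalar_info; // declared, never defined

template <int N>
struct scalar_info<Vec<float, N>>          // only float vectors are opted in here
{
    using type = float;
    static constexpr int vector_size = N;
};

int main()
{
    std::printf("%d\n", scalar_info<Vec<float, 4>>::vector_size); // OK: 4
    // scalar_info<Vec<int, 4>>::vector_size; // would not compile: incomplete type
}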
include/ck/utility/type_convert.hpp

@@ -404,6 +404,17 @@ inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
 #endif
 }
 
+// convert fp32 to fp8
+template <>
+inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<f8_ocp_t>(x);
+#else
+    return f8_convert_rne<f8_ocp_t>(x);
+#endif
+}
+
 // convert fp8 to fp32
 template <>
 inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)

@@ -461,6 +472,17 @@ inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
 #endif
 }
 
+// convert fp16 to fp8
+template <>
+inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<f8_ocp_t>(x);
+#else
+    return f8_convert_rne<f8_ocp_t>(x);
+#endif
+}
+
 // convert fp8 to fp16
 template <>
 inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)

@@ -485,6 +507,17 @@ inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
 #endif
 }
 
+// convert fp32 to bf8
+template <>
+inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<bf8_ocp_t>(x);
+#else
+    return f8_convert_rne<bf8_ocp_t>(x);
+#endif
+}
+
 // convert bf8 to fp32
 template <>
 inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)

@@ -512,6 +545,17 @@ inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
 #endif
 }
 
+// convert fp16 to bf8
+template <>
+inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<bf8_ocp_t>(x);
+#else
+    return f8_convert_rne<bf8_ocp_t>(x);
+#endif
+}
+
 // convert bf8 to fp16
 template <>
 inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
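With these four specializations, float and half_t values narrow to the OCP FP8/BF8 types through round-to-nearest-even by default, or stochastic rounding when CK_USE_SR_F8_CONVERSION is set, matching the existing FNUZ paths. A hedged usage sketch, assuming the CK include shown and that the widening direction can rely on f8_ocp_t's explicit operator float from amd_ck_fp8.hpp:

#include "ck/utility/type_convert.hpp"

void roundtrip(float x, float& back)
{
    // narrow: RNE by default, stochastic rounding if CK_USE_SR_F8_CONVERSION
    ck::f8_ocp_t q = ck::type_convert<ck::f8_ocp_t>(x);
    // widen back to fp32 for comparison against the original
    back = ck::type_convert<float>(q);
}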
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp

@@ -62,9 +62,9 @@ struct ReferenceGemm : public device::BaseOperator
             auto f_mk_kn_mn = [&](auto m, auto n) {
                 const int K = arg.a_m_k_.mDesc.GetLengths()[1];
 
-                AccDataType v_acc = 0;
-                ComputeTypeA v_a  = 0;
-                ComputeTypeB v_b  = 0;
+                AccDataType v_acc{0};
+                ComputeTypeA v_a{0};
+                ComputeTypeB v_b{0};
 
                 for(int k = 0; k < K; ++k)
                 {

@@ -93,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator
                         ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
                 }
 
-                CDataType v_c = 0;
+                CDataType v_c{0};
 
                 arg.c_element_op_(v_c, v_acc);
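The initializer change is the actual OCP fix here: `AccDataType v_acc = 0;` is copy-initialization and needs an implicit conversion from int, which class-type elements like f8_ocp_t do not provide, whereas `v_acc{0}` is direct-list-initialization and also covers explicit constructors and aggregates. A standalone illustration, with Q as a stand-in rather than CK's exact class:

struct Q
{
    unsigned char data;
    explicit Q(int v) : data(static_cast<unsigned char>(v)) {}
};

int main()
{
    // Q a = 0;  // error: copy-init requires an implicit conversion
    Q b{0};      // OK: direct-list-init may call the explicit constructor
    (void)b;
}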
library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp

@@ -25,7 +25,7 @@ template <typename ALayout,
           typename ComputeTypeB>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
     naive_gemm_kernel(const ADataType* __restrict__ p_a_grid,
                       const BDataType* __restrict__ p_b_grid,

@@ -45,10 +45,10 @@ __global__ void
     if(row_idx < m && col_idx < n)
     {
-        AccDataType v_acc = static_cast<AccDataType>(0.0);
-        ComputeTypeA v_a  = static_cast<ComputeTypeA>(0.0);
-        ComputeTypeB v_b  = static_cast<ComputeTypeB>(0.0);
-        CDataType v_c     = static_cast<CDataType>(0.0);
+        AccDataType v_acc{0};
+        ComputeTypeA v_a{0};
+        ComputeTypeB v_b{0};
+        CDataType v_c{0};
 
         for(int k_idx = 0; k_idx < k; ++k_idx)
         {
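The second hunk repeats the brace-initialization fix from the CPU reference above, this time inside the device kernel. As for the first hunk, __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) promises the compiler a maximum block size and a minimum number of resident blocks per compute unit so it can budget registers accordingly. A minimal HIP sketch of the attribute, with illustrative constants rather than CK's:

#include <hip/hip_runtime.h>

constexpr int kMaxThreadsPerBlock = 256;
constexpr int kMinBlocksPerCu     = 2;

__global__ void __launch_bounds__(kMaxThreadsPerBlock, kMinBlocksPerCu)
scale(float* p, float s, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        p[i] *= s; // compiler may cap register use to keep 2 blocks resident per CU
}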
library/include/ck/library/utility/host_tensor_generator.hpp

@@ -37,7 +37,7 @@ struct GeneratorTensor_1<ck::half_t>
     float value = 1.0;
 
     template <typename... Is>
-    ck::bhalf_t operator()(Is...)
+    ck::half_t operator()(Is...)
     {
         return ck::type_convert<ck::half_t>(value);
     }

@@ -62,7 +62,7 @@ struct GeneratorTensor_1<ck::f8_t>
     float value = 1.0;
 
     template <typename... Is>
-    ck::bhalf_t operator()(Is...)
+    ck::f8_t operator()(Is...)
     {
         return ck::type_convert<ck::f8_t>(value);
     }
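Both hunks fix copy-pasted return types: the half_t and f8_t specializations of GeneratorTensor_1 declared operator() as returning ck::bhalf_t even though the body converts to the specialization's own element type. The functor pattern in question, as a standalone sketch with illustrative names (Elem, GeneratorOne):

#include <cstdio>

struct Elem { float v; };

struct GeneratorOne
{
    float value = 1.0f;

    template <typename... Is>
    Elem operator()(Is...) const // return type must match the tensor's element type
    {
        return Elem{value};
    }
};

int main()
{
    GeneratorOne g;
    std::printf("%g\n", g(0, 1, 2).v); // index arguments are accepted and ignored
}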