Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0c0ddeff
Commit
0c0ddeff
authored
Feb 28, 2024
by
aska-0096
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/composable_kernel
into navi3_rel
parents
08ab9cfa
d0c7b451
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
16 additions
and
13 deletions
+16
-13
example/01_gemm/common.hpp
example/01_gemm/common.hpp
+1
-1
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+4
-3
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+10
-8
profiler/include/profiler/profile_elementwise_layernorm_impl.hpp
...r/include/profiler/profile_elementwise_layernorm_impl.hpp
+1
-1
No files found.
example/01_gemm/common.hpp
View file @
0c0ddeff
...
@@ -49,7 +49,7 @@ struct ProblemSizeStreamK final
...
@@ -49,7 +49,7 @@ struct ProblemSizeStreamK final
struct
ExecutionConfig
final
struct
ExecutionConfig
final
{
{
bool
do_verification
=
true
;
bool
do_verification
=
true
;
int
init_method
=
1
;
int
init_method
=
2
;
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
};
};
...
...
example/01_gemm/run_gemm_example.inc
View file @
0c0ddeff
...
@@ -85,8 +85,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -85,8 +85,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
2.
f
,
2.
f
}(
b_k_n
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
2.
f
,
2.
f
}(
b_k_n
);
break
;
break
;
default
:
default
:
ck
::
utils
::
FillUniformDistribution
<
ADataType
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistribution
<
ADataType
>
{
-
0.1
f
,
0.1
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistribution
<
BDataType
>
{
-
1.
f
,
1.
f
}(
b_k_n
);
ck
::
utils
::
FillUniformDistribution
<
BDataType
>
{
-
0.1
f
,
0.1
f
}(
b_k_n
);
}
}
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
...
@@ -256,7 +256,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -256,7 +256,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#else
#else
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
return
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
,
"Error: Incorrect results!"
,
1
e
-
1
,
1
e
-
1
);
#endif
#endif
}
}
...
...
include/ck/utility/type_convert.hpp
View file @
0c0ddeff
...
@@ -164,11 +164,12 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
...
@@ -164,11 +164,12 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
template
<
>
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
{
{
constexpr
int
seed
=
42
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx94__)
float
max_fp8
=
240.0
f
;
float
max_fp8
=
240.0
f
;
if
(
!
std
::
isinf
(
x
))
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
#if defined(__gfx94__)
union
union
{
{
float
fval
;
float
fval
;
...
@@ -201,7 +202,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
...
@@ -201,7 +202,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
...
@@ -213,7 +214,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
...
@@ -213,7 +214,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
template
<
>
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
{
{
constexpr
int
seed
=
42
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx94__)
#if defined(__gfx94__)
union
union
...
@@ -248,7 +249,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
...
@@ -248,7 +249,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
...
@@ -264,9 +265,10 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
...
@@ -264,9 +265,10 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
template
<
>
template
<
>
inline
__host__
__device__
f8_t
f8_convert_rne
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_t
f8_convert_rne
<
f8_t
,
float
>
(
float
x
)
{
{
#if defined(__gfx94__)
float
max_fp8
=
240.0
f
;
float
max_fp8
=
240.0
f
;
if
(
!
std
::
isinf
(
x
))
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
#if defined(__gfx94__)
union
union
{
{
float
fval
;
float
fval
;
...
...
profiler/include/profiler/profile_elementwise_layernorm_impl.hpp
View file @
0c0ddeff
...
@@ -233,7 +233,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
...
@@ -233,7 +233,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
y_dev
.
FromDevice
(
y
.
mData
.
data
());
y_dev
.
FromDevice
(
y
.
mData
.
data
());
bool
pass
=
bool
pass
=
ck
::
utils
::
check_err
(
y
.
mData
,
host_y
.
mData
,
"Error: Incorrect results"
,
1
e-3
,
1
e-3
);
ck
::
utils
::
check_err
(
y
.
mData
,
host_y
.
mData
,
"Error: Incorrect results"
,
5
e-3
,
5
e-3
);
if
(
do_log
)
if
(
do_log
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment