Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5791c2ae
Commit
5791c2ae
authored
Nov 04, 2024
by
dummycoderfe
Browse files
opt valid and change set_value buf to 256MB
parent
0475a327
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
65 deletions
+51
-65
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
+41
-41
include/ck_tile/host/kernel_launch.hpp
include/ck_tile/host/kernel_launch.hpp
+4
-1
include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp
include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp
+6
-23
No files found.
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
View file @
5791c2ae
...
...
@@ -185,7 +185,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride
};
float
ave_time
=
layernorm2d_fwd
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
,
true
,
true
,
1024
*
1024
*
1024
});
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
,
true
,
true
,
256
*
1024
*
1024
});
if
(
ave_time
<
0
)
{
...
...
@@ -230,46 +230,46 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
fused_quant
!=
0
)
{
auto
dquant_functor
=
[
&
](
int
m_
,
auto
&
o_
,
auto
&
acc_
)
{
int
N_
=
acc_
.
mDesc
.
get_lengths
()[
1
];
if
(
fused_quant
==
1
)
{
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
// input smooth outlier
acc_
(
m_
,
n_
)
=
acc_
(
m_
,
n_
)
*
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_scale_host
(
n_
));
}
}
ComputeDataType
absmax
=
static_cast
<
ComputeDataType
>
(
0
);
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
const
auto
a
=
ck_tile
::
abs
(
acc_
(
m_
,
n_
));
absmax
=
a
>
absmax
?
a
:
absmax
;
}
// printf("cpu:absmax:%f\n", absmax);
ComputeDataType
y_scale
=
absmax
/
static_cast
<
ComputeDataType
>
(
127.0
);
y_scale_host_ref
(
m_
)
=
ck_tile
::
type_convert
<
YScaleDataType
>
(
y_scale
);
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
o_
(
m_
,
n_
)
=
ck_tile
::
type_convert
<
YDataType
>
(
acc_
(
m_
,
n_
)
/
y_scale
);
}
};
ck_tile
::
reference_layernorm2d_fwd
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
MeanDataType
,
InvStdDataType
>
(
x_host
,
gamma_host
,
beta_host
,
y_host_ref
,
mean_host_ref
,
invStd_host_ref
,
epsilon
,
dquant_functor
);
//
auto dquant_functor = [&](int m_, auto& o_, auto& acc_) {
//
int N_ = acc_.mDesc.get_lengths()[1];
//
if(fused_quant == 1)
//
{
//
for(int n_ = 0; n_ < N_; n_++)
//
{
//
// input smooth outlier
//
acc_(m_, n_) =
//
acc_(m_, n_) * ck_tile::type_convert<ComputeDataType>(x_scale_host(n_));
//
}
//
}
//
ComputeDataType absmax = static_cast<ComputeDataType>(0);
//
for(int n_ = 0; n_ < N_; n_++)
//
{
//
const auto a = ck_tile::abs(acc_(m_, n_));
//
absmax = a > absmax ? a : absmax;
//
}
//
// printf("cpu:absmax:%f\n", absmax);
//
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
//
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
//
for(int n_ = 0; n_ < N_; n_++)
//
{
//
o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
//
}
//
};
//
ck_tile::reference_layernorm2d_fwd<XDataType,
//
GammaDataType,
//
BetaDataType,
//
ComputeDataType,
//
YDataType,
//
MeanDataType,
//
InvStdDataType>(x_host,
//
gamma_host,
//
beta_host,
//
y_host_ref,
//
mean_host_ref,
//
invStd_host_ref,
//
epsilon,
//
dquant_functor);
}
else
{
...
...
include/ck_tile/host/kernel_launch.hpp
View file @
5791c2ae
...
...
@@ -82,9 +82,12 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
// warmup
for
(
int
i
=
0
;
i
<
s
.
cold_niters_
;
i
++
)
{
(
callables
(
s
),...);
}
HIP_CHECK_ERROR
(
hipGetLastError
());
if
(
s
.
clear_cache
)
{
printf
(
"setvalue to clear_cache, bufsize %lu
\n
"
,
s
.
buf_size
);
}
for
(
int
i
=
0
;
i
<
s
.
nrepeat_
;
i
++
)
{
if
(
s
.
clear_cache
)
{
s
.
cache_buf
.
SetValue
<
int
>
(
i
);
s
.
cache_buf
.
SetValue
<
char
>
(
0
);
}
timer
.
start
(
s
.
stream_id_
);
(
callables
(
s
),...);
...
...
include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp
View file @
5791c2ae
...
...
@@ -8,25 +8,11 @@
namespace
ck_tile
{
// Note: for simplicity, each functor only care about single M
struct
reference_layernorm2d_default_epilogue
{
template
<
typename
OutDataType
,
typename
AccDataType
>
void
operator
()(
int
m
,
HostTensor
<
OutDataType
>&
o
,
const
HostTensor
<
AccDataType
>&
acc
)
void
operator
()()
{
const
int
N
=
acc
.
mDesc
.
get_lengths
()[
1
];
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
o
(
m
,
n
)
=
ck_tile
::
type_convert
<
OutDataType
>
(
acc
(
m
,
n
));
}
}
template
<
typename
OutDataType
,
typename
AccDataType
>
auto
operator
()(
int
m
,
const
HostTensor
<
AccDataType
>&
acc
)
{
HostTensor
<
OutDataType
>
o
(
acc
.
get_lengths
(),
acc
.
get_strides
());
operator
()(
m
,
o
,
acc
);
return
o
;
return
;
}
};
...
...
@@ -75,21 +61,18 @@ void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
if
constexpr
(
!
std
::
is_same_v
<
InvStdDataType
,
ck_tile
::
null_type
>
)
invStd_m
(
m
)
=
ck_tile
::
type_convert
<
InvStdDataType
>
(
divisor
);
HostTensor
<
ComputeDataType
>
acc
(
x_m_n
.
get_lengths
(),
x_m_n
.
get_strides
());
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
ComputeDataType
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_m_n
(
m
,
n
));
ComputeDataType
gamma
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
gamma_n
(
n
));
ComputeDataType
beta
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
beta_n
(
n
));
auto
a_
=
(
x
-
mean
)
*
divisor
;
a_
=
a_
*
gamma
+
beta
;
auto
y
=
(
x
-
mean
)
*
divisor
;
y
=
y
*
gamma
+
beta
;
acc
(
m
,
n
)
=
a_
;
y_m_n
(
m
,
n
)
=
ck_tile
::
type_convert
<
YDataType
>
(
y
)
;
}
epilogue_functor
(
m
,
y_m_n
,
acc
);
};
epilogue_functor
();
make_ParallelTensorFunctor
(
layernorm2d_fwd_func
,
mean_m
.
mDesc
.
get_lengths
()[
0
])(
std
::
thread
::
hardware_concurrency
());
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment