Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
309705c1
Commit
309705c1
authored
Jan 09, 2025
by
coderfeli
Browse files
add debug datra
parent
a6b761c3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
25 deletions
+30
-25
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+15
-19
include/ck_tile/host/check_err.hpp
include/ck_tile/host/check_err.hpp
+15
-6
No files found.
example/ck_tile/15_fused_moe/main.cpp
View file @
309705c1
...
@@ -155,7 +155,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -155,7 +155,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
int
gate_only
=
arg_parser
.
get_int
(
"gate_only"
);
int
gate_only
=
arg_parser
.
get_int
(
"gate_only"
);
int
api
=
arg_parser
.
get_int
(
"api"
);
int
api
=
arg_parser
.
get_int
(
"api"
);
int
balance
=
arg_parser
.
get_int
(
"balance"
);
//
int balance = arg_parser.get_int("balance");
int
tp
=
arg_parser
.
get_int
(
"tp"
);
int
tp
=
arg_parser
.
get_int
(
"tp"
);
int
init
=
arg_parser
.
get_int
(
"init"
);
int
init
=
arg_parser
.
get_int
(
"init"
);
uint32_t
seed
=
arg_parser
.
get_uint32
(
"seed"
);
uint32_t
seed
=
arg_parser
.
get_uint32
(
"seed"
);
...
@@ -257,14 +257,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -257,14 +257,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
else
if
(
init
==
2
)
else
if
(
init
==
2
)
{
{
ck_tile
::
FillNormalDistribution
<
ADataType
>
{
0
.
f
,
1.
f
,
seed
,
true
}(
a_host
);
ck_tile
::
FillNormalDistribution
<
ADataType
>
{
1
.
f
,
1.
f
,
seed
,
true
}(
a_host
);
ck_tile
::
FillNormalDistribution
<
GDataType
>
{
0
.
f
,
1.
f
,
seed
,
true
}(
g_host
);
ck_tile
::
FillNormalDistribution
<
GDataType
>
{
1
.
f
,
1.
f
,
seed
,
true
}(
g_host
);
ck_tile
::
FillNormalDistribution
<
DDataType
>
{
0
.
f
,
1.
f
,
seed
,
true
}(
d_host
);
ck_tile
::
FillNormalDistribution
<
DDataType
>
{
1
.
f
,
1.
f
,
seed
,
true
}(
d_host
);
ck_tile
::
FillNormalDistribution
<
AScaleDataType
>
{
0
.
f
,
1.
f
,
seed
,
true
}(
sa_host
);
ck_tile
::
FillNormalDistribution
<
AScaleDataType
>
{
1
.
f
,
1.
f
,
seed
,
true
}(
sa_host
);
ck_tile
::
FillNormalDistribution
<
GScaleDataType
>
{
0
.
f
,
1.
f
,
seed
,
true
}(
sg_host
);
ck_tile
::
FillNormalDistribution
<
GScaleDataType
>
{
1
.
f
,
1.
f
,
seed
,
true
}(
sg_host
);
ck_tile
::
FillNormalDistribution
<
DScaleDataType
>
{
0
.
f
,
1.
f
,
seed
,
true
}(
sd_host
);
ck_tile
::
FillNormalDistribution
<
DScaleDataType
>
{
1
.
f
,
1.
f
,
seed
,
true
}(
sd_host
);
ck_tile
::
FillNormalDistribution
<
YSmoothScaleDataType
>
{
0
.
f
,
1.
f
,
seed
,
true
}(
sy_host
);
ck_tile
::
FillNormalDistribution
<
YSmoothScaleDataType
>
{
1
.
f
,
1.
f
,
seed
,
true
}(
sy_host
);
ck_tile
::
FillNormalDistribution
<
TopkWeightDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
topk_weight_host
);
ck_tile
::
FillNormalDistribution
<
TopkWeightDataType
>
{
0.
125
f
,
0.125
f
,
seed
,
true
}(
topk_weight_host
);
}
}
// permute weight
// permute weight
...
@@ -272,15 +272,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -272,15 +272,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
DDataType
>
d_perm_host
=
shuffle_moe_weight
(
d_host
,
prec_w
,
1
);
ck_tile
::
HostTensor
<
DDataType
>
d_perm_host
=
shuffle_moe_weight
(
d_host
,
prec_w
,
1
);
// do moe sorting
// do moe sorting
if
(
balance
)
if
(
1
)
{
{
int
e_cnt
=
0
;
for
(
int
i
=
0
;
i
<
topk
;
i
++
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
topk_ids_host
.
mData
.
size
());
i
++
)
topk_ids_host
.
mData
[
i
]
=
i
;
{
topk_weight_host
.
mData
[
i
]
=
0.1
;
topk_ids_host
.
mData
[
i
]
=
e_cnt
;
e_cnt
++
;
if
(
e_cnt
>=
experts
)
e_cnt
=
0
;
}
}
}
}
else
else
...
@@ -420,7 +416,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -420,7 +416,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
experts
,
experts
,
block_m
);
block_m
);
ck_tile
::
reference_fused_moe
<
AccDataType
,
ck_tile
::
element_wise
::
Ge
lu
>
(
ck_tile
::
reference_fused_moe
<
AccDataType
,
ck_tile
::
element_wise
::
Si
lu
>
(
a_host
,
a_host
,
g_host
,
g_host
,
d_host
,
d_host
,
...
@@ -529,7 +525,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -529,7 +525,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
do_validation
)
if
(
do_validation
)
{
{
ck_tile
::
reference_fused_moe
<
AccDataType
,
ck_tile
::
element_wise
::
Ge
lu
>
(
ck_tile
::
reference_fused_moe
<
AccDataType
,
ck_tile
::
element_wise
::
Si
lu
>
(
a_host
,
a_host
,
g_host
,
g_host
,
d_host
,
d_host
,
...
...
include/ck_tile/host/check_err.hpp
View file @
309705c1
...
@@ -62,6 +62,7 @@ check_err(const Range& out,
...
@@ -62,6 +62,7 @@ check_err(const Range& out,
return
either_not_finite
&&
!
(
allow_infinity_ref
&&
both_infinite_and_same
);
return
either_not_finite
&&
!
(
allow_infinity_ref
&&
both_infinite_and_same
);
};
};
printf
(
"1090111
\n
"
);
bool
res
{
true
};
bool
res
{
true
};
int
err_count
=
0
;
int
err_count
=
0
;
...
@@ -76,7 +77,7 @@ check_err(const Range& out,
...
@@ -76,7 +77,7 @@ check_err(const Range& out,
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
5
00
)
{
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
@@ -107,6 +108,7 @@ check_err(const Range& out,
...
@@ -107,6 +108,7 @@ check_err(const Range& out,
double
atol
=
1e-3
,
double
atol
=
1e-3
,
bool
allow_infinity_ref
=
false
)
bool
allow_infinity_ref
=
false
)
{
{
printf
(
"1111
\n
"
);
if
(
out
.
size
()
!=
ref
.
size
())
if
(
out
.
size
()
!=
ref
.
size
())
{
{
std
::
cerr
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
std
::
cerr
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
...
@@ -127,6 +129,7 @@ check_err(const Range& out,
...
@@ -127,6 +129,7 @@ check_err(const Range& out,
double
err
=
0
;
double
err
=
0
;
// TODO: This is a hack. We should have proper specialization for bf16_t data type.
// TODO: This is a hack. We should have proper specialization for bf16_t data type.
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
int
print_cnt
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
{
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
...
@@ -136,12 +139,16 @@ check_err(const Range& out,
...
@@ -136,12 +139,16 @@ check_err(const Range& out,
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
5
00
)
{
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
}
}
res
=
false
;
res
=
false
;
}
else
if
(
print_cnt
<
10
)
{
print_cnt
++
;
std
::
cout
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
}
}
}
}
if
(
!
res
)
if
(
!
res
)
...
@@ -195,7 +202,7 @@ check_err(const Range& out,
...
@@ -195,7 +202,7 @@ check_err(const Range& out,
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
5
00
)
{
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
@@ -235,6 +242,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
...
@@ -235,6 +242,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
<<
std
::
endl
;
<<
std
::
endl
;
return
false
;
return
false
;
}
}
printf
(
"222
\n
"
);
bool
res
{
true
};
bool
res
{
true
};
int
err_count
=
0
;
int
err_count
=
0
;
...
@@ -250,7 +258,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
...
@@ -250,7 +258,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
5
00
)
{
{
std
::
cerr
<<
msg
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
std
::
cerr
<<
msg
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
std
::
endl
;
...
@@ -313,6 +321,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
...
@@ -313,6 +321,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
bool
res
{
true
};
bool
res
{
true
};
int
err_count
=
0
;
int
err_count
=
0
;
double
err
=
0
;
double
err
=
0
;
printf
(
"11113
\n
"
);
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
{
...
@@ -327,7 +336,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
...
@@ -327,7 +336,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
5
00
)
{
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o_fp64
<<
" != "
<<
r_fp64
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o_fp64
<<
" != "
<<
r_fp64
<<
std
::
endl
;
...
@@ -381,7 +390,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
...
@@ -381,7 +390,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
if
(
err_count
<
5
)
if
(
err_count
<
5
00
)
{
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment