Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8107ee62
"...resnet50_tensorflow.git" did not exist on "6894a6604e882211a0b42c6fda3d115f977f684a"
Unverified
Commit
8107ee62
authored
Aug 29, 2024
by
Po Yen Chen
Committed by
GitHub
Aug 29, 2024
Browse files
Add missing function and parameters (#1493)
parent
c1569892
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
5 deletions
+16
-5
example/ck_tile/01_fmha/utils.hpp
example/ck_tile/01_fmha/utils.hpp
+16
-5
No files found.
example/ck_tile/01_fmha/utils.hpp
View file @
8107ee62
...
@@ -39,7 +39,8 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
...
@@ -39,7 +39,8 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
return
seqstarts
;
return
seqstarts
;
}
}
std
::
vector
<
int32_t
>
generate_seqlens
(
unsigned
count
,
std
::
vector
<
int32_t
>
generate_seqlens
(
mode_enum
mode
,
unsigned
count
,
int32_t
seqlen_avg
,
int32_t
seqlen_avg
,
int32_t
seqlen_min
=
-
1
,
// if not negative, clamp min
int32_t
seqlen_min
=
-
1
,
// if not negative, clamp min
int32_t
seqlen_max
=
-
1
,
// if not negative, clamp max
int32_t
seqlen_max
=
-
1
,
// if not negative, clamp max
...
@@ -53,7 +54,7 @@ std::vector<int32_t> generate_seqlens(unsigned count,
...
@@ -53,7 +54,7 @@ std::vector<int32_t> generate_seqlens(unsigned count,
std
::
vector
<
int32_t
>
seqlens
(
count
,
std
::
clamp
(
seqlen_avg
,
seqlen_min
,
seqlen_max
));
std
::
vector
<
int32_t
>
seqlens
(
count
,
std
::
clamp
(
seqlen_avg
,
seqlen_min
,
seqlen_max
));
if
(
1
<
count
)
if
(
mode
==
mode_enum
::
group
&&
1
<
count
)
{
{
using
size_type
=
std
::
vector
<
int32_t
>::
size_type
;
using
size_type
=
std
::
vector
<
int32_t
>::
size_type
;
...
@@ -67,7 +68,7 @@ std::vector<int32_t> generate_seqlens(unsigned count,
...
@@ -67,7 +68,7 @@ std::vector<int32_t> generate_seqlens(unsigned count,
for
(
unsigned
repeat
=
seqlen_avg
*
(
count
/
2
);
0
<
repeat
;
--
repeat
)
for
(
unsigned
repeat
=
seqlen_avg
*
(
count
/
2
);
0
<
repeat
;
--
repeat
)
{
{
const
size_type
to_decrease
=
next_idx
();
const
size_type
to_decrease
=
next_idx
();
// make sure each elements of seqlens is
always greater than
seqlen_m
in
// make sure each elements of seqlens is
in range [seqlen_min,
seqlen_m
ax]
if
(
seqlens
[
to_decrease
]
==
seqlen_min
)
if
(
seqlens
[
to_decrease
]
==
seqlen_min
)
{
{
continue
;
continue
;
...
@@ -88,6 +89,16 @@ std::vector<int32_t> generate_seqlens(unsigned count,
...
@@ -88,6 +89,16 @@ std::vector<int32_t> generate_seqlens(unsigned count,
return
seqlens
;
return
seqlens
;
}
}
std
::
vector
<
int32_t
>
generate_seqstarts
(
mode_enum
mode
,
unsigned
count
,
int32_t
seqlen_avg
,
int32_t
seqlen_min
=
-
1
,
int32_t
seqlen_max
=
-
1
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
{
return
to_seqstarts
(
generate_seqlens
(
mode
,
count
,
seqlen_avg
,
seqlen_min
,
seqlen_max
,
seed
));
}
// return random integer generated uniformly in range [low, high]
// return random integer generated uniformly in range [low, high]
template
<
typename
Int
=
int
>
template
<
typename
Int
=
int
>
auto
randint
(
Int
low
,
Int
high
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
auto
randint
(
Int
low
,
Int
high
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
...
@@ -220,9 +231,9 @@ decode_seqlen(mode_enum mode,
...
@@ -220,9 +231,9 @@ decode_seqlen(mode_enum mode,
}
}
if
(
idx
<
batch
)
if
(
idx
<
batch
)
{
{
auto
rem_q
=
generate_seqlens
(
batch
-
idx
,
s_q
.
back
(),
1
,
s_kpad
.
back
(),
seed
);
auto
rem_q
=
generate_seqlens
(
mode
,
batch
-
idx
,
s_q
.
back
(),
1
,
s_kpad
.
back
(),
seed
);
auto
rem_k
=
auto
rem_k
=
generate_seqlens
(
batch
-
idx
,
s_k
.
back
(),
seqlen_k_min
,
s_kpad
.
back
(),
seed
);
generate_seqlens
(
mode
,
batch
-
idx
,
s_k
.
back
(),
seqlen_k_min
,
s_kpad
.
back
(),
seed
);
s_q
.
insert
(
s_q
.
end
(),
rem_q
.
begin
(),
rem_q
.
end
());
s_q
.
insert
(
s_q
.
end
(),
rem_q
.
begin
(),
rem_q
.
end
());
s_k
.
insert
(
s_k
.
end
(),
rem_k
.
begin
(),
rem_k
.
end
());
s_k
.
insert
(
s_k
.
end
(),
rem_k
.
begin
(),
rem_k
.
end
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment