"test/srt/vscode:/vscode.git/clone" did not exist on "d26ca84f39ab322773defc126973549f83b8954f"
Unverified Commit 8107ee62 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Add missing function and parameters (#1493)

parent c1569892
......@@ -39,7 +39,8 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
return seqstarts;
}
std::vector<int32_t> generate_seqlens(unsigned count,
std::vector<int32_t> generate_seqlens(mode_enum mode,
unsigned count,
int32_t seqlen_avg,
int32_t seqlen_min = -1, // if not negative, clamp min
int32_t seqlen_max = -1, // if not negative, clamp max
......@@ -53,7 +54,7 @@ std::vector<int32_t> generate_seqlens(unsigned count,
std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));
if(1 < count)
if(mode == mode_enum::group && 1 < count)
{
using size_type = std::vector<int32_t>::size_type;
......@@ -67,7 +68,7 @@ std::vector<int32_t> generate_seqlens(unsigned count,
for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
{
const size_type to_decrease = next_idx();
// make sure each elements of seqlens is always greater than seqlen_min
// make sure each element of seqlens stays in range [seqlen_min, seqlen_max]
if(seqlens[to_decrease] == seqlen_min)
{
continue;
......@@ -88,6 +89,16 @@ std::vector<int32_t> generate_seqlens(unsigned count,
return seqlens;
}
// Convenience wrapper: generates sequence lengths via generate_seqlens() and
// converts them with to_seqstarts().
//
// Every argument is forwarded unchanged to generate_seqlens(); see that
// function for the meaning of mode / count / seqlen_avg / seqlen_min /
// seqlen_max / seed. The return value is whatever to_seqstarts() produces for
// the generated lengths.
std::vector<int32_t> generate_seqstarts(mode_enum mode,
                                        unsigned count,
                                        int32_t seqlen_avg,
                                        int32_t seqlen_min            = -1,
                                        int32_t seqlen_max            = -1,
                                        std::optional<unsigned> seed = std::nullopt)
{
    // Materialize the lengths first so the two steps read separately.
    const auto lengths =
        generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed);

    return to_seqstarts(lengths);
}
// return random integer generated uniformly in range [low, high]
template <typename Int = int>
auto randint(Int low, Int high, std::optional<unsigned> seed = std::nullopt)
......@@ -220,9 +231,9 @@ decode_seqlen(mode_enum mode,
}
if(idx < batch)
{
auto rem_q = generate_seqlens(batch - idx, s_q.back(), 1, s_kpad.back(), seed);
auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed);
auto rem_k =
generate_seqlens(batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed);
generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed);
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment