Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2f1c19b2
Unverified
Commit
2f1c19b2
authored
Jun 12, 2025
by
Ning Xie
Committed by
GitHub
Jun 11, 2025
Browse files
[CI] change spell checker from codespell to typos (#18711)
Signed-off-by:
Andy Xie
<
andy.xning@gmail.com
>
parent
42f52cc9
Changes
57
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
72 additions
and
79 deletions
+72
-79
.gitignore
.gitignore
+1
-1
.pre-commit-config.yaml
.pre-commit-config.yaml
+3
-5
csrc/cpu/attention.cpp
csrc/cpu/attention.cpp
+3
-3
csrc/cpu/cpu_types_x86.hpp
csrc/cpu/cpu_types_x86.hpp
+5
-5
csrc/moe/moe_permute_unpermute_op.cu
csrc/moe/moe_permute_unpermute_op.cu
+8
-8
csrc/moe/topk_softmax_kernels.cu
csrc/moe/topk_softmax_kernels.cu
+3
-3
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+1
-1
csrc/quantization/machete/machete_mainloop.cuh
csrc/quantization/machete/machete_mainloop.cuh
+3
-3
csrc/rocm/skinny_gemms.cu
csrc/rocm/skinny_gemms.cu
+7
-7
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
+1
-1
pyproject.toml
pyproject.toml
+0
-4
tests/compile/test_async_tp.py
tests/compile/test_async_tp.py
+2
-2
tests/core/block/e2e/test_correctness.py
tests/core/block/e2e/test_correctness.py
+2
-2
tests/core/block/e2e/test_correctness_sliding_window.py
tests/core/block/e2e/test_correctness_sliding_window.py
+3
-3
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+2
-2
tests/entrypoints/openai/test_chat_template.py
tests/entrypoints/openai/test_chat_template.py
+2
-2
tests/kernels/attention/test_cache.py
tests/kernels/attention/test_cache.py
+8
-8
tests/kernels/attention/test_encoder_decoder_attn.py
tests/kernels/attention/test_encoder_decoder_attn.py
+1
-1
tests/kernels/core/test_rotary_embedding.py
tests/kernels/core/test_rotary_embedding.py
+4
-4
tests/kernels/mamba/test_mamba_ssm_ssd.py
tests/kernels/mamba/test_mamba_ssm_ssd.py
+13
-14
No files found.
.gitignore
View file @
2f1c19b2
...
@@ -200,5 +200,5 @@ benchmarks/**/*.json
...
@@ -200,5 +200,5 @@ benchmarks/**/*.json
actionlint
actionlint
shellcheck*/
shellcheck*/
# I
n
gore moe/marlin_moe gen code
# Ig
n
ore moe/marlin_moe gen code
csrc/moe/marlin_moe_wna16/kernel_*
csrc/moe/marlin_moe_wna16/kernel_*
.pre-commit-config.yaml
View file @
2f1c19b2
...
@@ -20,12 +20,10 @@ repos:
...
@@ -20,12 +20,10 @@ repos:
args
:
[
--output-format
,
github
,
--fix
]
args
:
[
--output-format
,
github
,
--fix
]
-
id
:
ruff-format
-
id
:
ruff-format
files
:
^(.buildkite|benchmarks|examples)/.*
files
:
^(.buildkite|benchmarks|examples)/.*
-
repo
:
https://github.com/c
odespell-project/codespell
-
repo
:
https://github.com/c
rate-ci/typos
rev
:
v
2.4.1
rev
:
v
1.32.0
hooks
:
hooks
:
-
id
:
codespell
-
id
:
typos
additional_dependencies
:
[
'
tomli'
]
args
:
[
'
--toml'
,
'
pyproject.toml'
]
-
repo
:
https://github.com/PyCQA/isort
-
repo
:
https://github.com/PyCQA/isort
rev
:
6.0.1
rev
:
6.0.1
hooks
:
hooks
:
...
...
csrc/cpu/attention.cpp
View file @
2f1c19b2
...
@@ -137,8 +137,8 @@ FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
...
@@ -137,8 +137,8 @@ FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
}
}
template
<
typename
T
>
template
<
typename
T
>
FORCE_INLINE
void
reducePartitonSoftmax
(
const
T
*
max_data
,
T
*
sum_data
,
FORCE_INLINE
void
reducePartit
i
onSoftmax
(
const
T
*
max_data
,
T
*
sum_data
,
const
int
size
)
{
const
int
size
)
{
T
max
=
max_data
[
0
];
T
max
=
max_data
[
0
];
for
(
int
i
=
1
;
i
<
size
;
++
i
)
{
for
(
int
i
=
1
;
i
<
size
;
++
i
)
{
max
=
max
>=
max_data
[
i
]
?
max
:
max_data
[
i
];
max
=
max
>=
max_data
[
i
]
?
max
:
max_data
[
i
];
...
@@ -634,7 +634,7 @@ struct paged_attention_v2_impl {
...
@@ -634,7 +634,7 @@ struct paged_attention_v2_impl {
if
(
partition_num
==
1
)
continue
;
if
(
partition_num
==
1
)
continue
;
reducePartitonSoftmax
(
reducePartit
i
onSoftmax
(
max_logits
+
seq_idx
*
num_heads
*
max_num_partitions
+
max_logits
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
,
head_idx
*
max_num_partitions
,
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
...
...
csrc/cpu/cpu_types_x86.hpp
View file @
2f1c19b2
...
@@ -83,7 +83,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
...
@@ -83,7 +83,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
explicit
FP16Vec16
(
const
void
*
ptr
)
explicit
FP16Vec16
(
const
void
*
ptr
)
:
reg
((
__m256i
)
_mm256_loadu_si256
((
__m256i
*
)
ptr
))
{}
:
reg
((
__m256i
)
_mm256_loadu_si256
((
__m256i
*
)
ptr
))
{}
// non-temp
r
oal load
// non-tempo
r
al load
explicit
FP16Vec16
(
bool
,
void
*
ptr
)
explicit
FP16Vec16
(
bool
,
void
*
ptr
)
:
reg
(
_mm256_stream_load_si256
((
__m256i
*
)
ptr
))
{}
:
reg
(
_mm256_stream_load_si256
((
__m256i
*
)
ptr
))
{}
...
@@ -120,7 +120,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
...
@@ -120,7 +120,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
explicit
BF16Vec16
(
const
void
*
ptr
)
explicit
BF16Vec16
(
const
void
*
ptr
)
:
reg
((
__m256i
)
_mm256_loadu_si256
((
__m256i
*
)
ptr
))
{}
:
reg
((
__m256i
)
_mm256_loadu_si256
((
__m256i
*
)
ptr
))
{}
// non-temp
r
oal load
// non-tempo
r
al load
explicit
BF16Vec16
(
bool
,
void
*
ptr
)
explicit
BF16Vec16
(
bool
,
void
*
ptr
)
:
reg
(
_mm256_stream_load_si256
((
__m256i
*
)
ptr
))
{}
:
reg
(
_mm256_stream_load_si256
((
__m256i
*
)
ptr
))
{}
...
@@ -327,7 +327,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
...
@@ -327,7 +327,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
// normal load
// normal load
explicit
FP32Vec16
(
const
float
*
ptr
)
:
reg
(
_mm512_loadu_ps
(
ptr
))
{}
explicit
FP32Vec16
(
const
float
*
ptr
)
:
reg
(
_mm512_loadu_ps
(
ptr
))
{}
// non-temp
r
oal load
// non-tempo
r
al load
explicit
FP32Vec16
(
bool
,
void
*
ptr
)
explicit
FP32Vec16
(
bool
,
void
*
ptr
)
:
reg
((
__m512
)
_mm512_stream_load_si512
(
ptr
))
{}
:
reg
((
__m512
)
_mm512_stream_load_si512
(
ptr
))
{}
...
@@ -576,7 +576,7 @@ struct INT8Vec64 : public Vec<INT8Vec64> {
...
@@ -576,7 +576,7 @@ struct INT8Vec64 : public Vec<INT8Vec64> {
// normal load
// normal load
explicit
INT8Vec64
(
void
*
ptr
)
:
reg
(
_mm512_loadu_epi8
(
ptr
))
{}
explicit
INT8Vec64
(
void
*
ptr
)
:
reg
(
_mm512_loadu_epi8
(
ptr
))
{}
// non-temp
r
oal load
// non-tempo
r
al load
explicit
INT8Vec64
(
bool
,
void
*
ptr
)
:
reg
(
_mm512_stream_load_si512
(
ptr
))
{}
explicit
INT8Vec64
(
bool
,
void
*
ptr
)
:
reg
(
_mm512_stream_load_si512
(
ptr
))
{}
void
save
(
void
*
ptr
)
const
{
_mm512_storeu_epi8
(
ptr
,
reg
);
}
void
save
(
void
*
ptr
)
const
{
_mm512_storeu_epi8
(
ptr
,
reg
);
}
...
@@ -587,7 +587,7 @@ struct INT8Vec64 : public Vec<INT8Vec64> {
...
@@ -587,7 +587,7 @@ struct INT8Vec64 : public Vec<INT8Vec64> {
_mm512_mask_storeu_epi8
(
ptr
,
mask
,
reg
);
_mm512_mask_storeu_epi8
(
ptr
,
mask
,
reg
);
}
}
// non-temp
r
oal save
// non-tempo
r
al save
void
nt_save
(
int8_t
*
ptr
)
{
_mm512_stream_si512
((
__m512i
*
)
ptr
,
reg
);
}
void
nt_save
(
int8_t
*
ptr
)
{
_mm512_stream_si512
((
__m512i
*
)
ptr
,
reg
);
}
};
};
#endif
#endif
...
...
csrc/moe/moe_permute_unpermute_op.cu
View file @
2f1c19b2
...
@@ -12,7 +12,7 @@ void moe_permute(
...
@@ -12,7 +12,7 @@ void moe_permute(
const
torch
::
Tensor
&
input
,
// [n_token, hidden]
const
torch
::
Tensor
&
input
,
// [n_token, hidden]
const
torch
::
Tensor
&
topk_weights
,
//[n_token, topk]
const
torch
::
Tensor
&
topk_weights
,
//[n_token, topk]
torch
::
Tensor
&
topk_ids
,
// [n_token, topk]
torch
::
Tensor
&
topk_ids
,
// [n_token, topk]
const
torch
::
Tensor
&
token_expert_indic
i
es
,
// [n_token, topk]
const
torch
::
Tensor
&
token_expert_indices
,
// [n_token, topk]
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
// [n_expert]
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
// [n_expert]
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
const
std
::
optional
<
int64_t
>&
align_block_size
,
const
std
::
optional
<
int64_t
>&
align_block_size
,
...
@@ -27,15 +27,15 @@ void moe_permute(
...
@@ -27,15 +27,15 @@ void moe_permute(
"expert_first_token_offset must be int64"
);
"expert_first_token_offset must be int64"
);
TORCH_CHECK
(
topk_ids
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
TORCH_CHECK
(
topk_ids
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
"topk_ids must be int32"
);
"topk_ids must be int32"
);
TORCH_CHECK
(
token_expert_indic
i
es
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
TORCH_CHECK
(
token_expert_indices
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
"token_expert_indic
i
es must be int32"
);
"token_expert_indices must be int32"
);
TORCH_CHECK
(
src_row_id2dst_row_id_map
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
TORCH_CHECK
(
src_row_id2dst_row_id_map
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
"src_row_id2dst_row_id_map must be int32"
);
"src_row_id2dst_row_id_map must be int32"
);
TORCH_CHECK
(
expert_first_token_offset
.
size
(
0
)
==
n_local_expert
+
1
,
TORCH_CHECK
(
expert_first_token_offset
.
size
(
0
)
==
n_local_expert
+
1
,
"expert_first_token_offset shape != n_local_expert+1"
)
"expert_first_token_offset shape != n_local_expert+1"
)
TORCH_CHECK
(
TORCH_CHECK
(
src_row_id2dst_row_id_map
.
sizes
()
==
token_expert_indic
i
es
.
sizes
(),
src_row_id2dst_row_id_map
.
sizes
()
==
token_expert_indices
.
sizes
(),
"token_expert_indic
i
es shape must be same as src_row_id2dst_row_id_map"
);
"token_expert_indices shape must be same as src_row_id2dst_row_id_map"
);
auto
n_token
=
input
.
sizes
()[
0
];
auto
n_token
=
input
.
sizes
()[
0
];
auto
n_hidden
=
input
.
sizes
()[
1
];
auto
n_hidden
=
input
.
sizes
()[
1
];
auto
align_block_size_value
=
auto
align_block_size_value
=
...
@@ -71,7 +71,7 @@ void moe_permute(
...
@@ -71,7 +71,7 @@ void moe_permute(
expert_map_ptr
,
n_expert
,
stream
);
expert_map_ptr
,
n_expert
,
stream
);
}
}
// expert sort topk expert id and scan expert id get expert_first_token_offset
// expert sort topk expert id and scan expert id get expert_first_token_offset
sortAndScanExpert
(
get_ptr
<
int
>
(
topk_ids
),
get_ptr
<
int
>
(
token_expert_indic
i
es
),
sortAndScanExpert
(
get_ptr
<
int
>
(
topk_ids
),
get_ptr
<
int
>
(
token_expert_indices
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
dst_row_id2src_row_id_map
),
get_ptr
<
int
>
(
dst_row_id2src_row_id_map
),
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
...
@@ -190,7 +190,7 @@ void shuffle_rows(const torch::Tensor& input_tensor,
...
@@ -190,7 +190,7 @@ void shuffle_rows(const torch::Tensor& input_tensor,
void
moe_permute
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
topk_weights
,
void
moe_permute
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
token_expert_indic
i
es
,
const
torch
::
Tensor
&
token_expert_indices
,
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
const
std
::
optional
<
int64_t
>&
align_block_size
,
const
std
::
optional
<
int64_t
>&
align_block_size
,
...
@@ -203,7 +203,7 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
...
@@ -203,7 +203,7 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
void
moe_unpermute
(
const
torch
::
Tensor
&
input
,
void
moe_unpermute
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
token_expert_indic
i
es
,
const
torch
::
Tensor
&
token_expert_indices
,
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
const
std
::
optional
<
int64_t
>&
align_block_size
,
const
std
::
optional
<
int64_t
>&
align_block_size
,
...
...
csrc/moe/topk_softmax_kernels.cu
View file @
2f1c19b2
...
@@ -425,7 +425,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
...
@@ -425,7 +425,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
gating_output, nullptr, topk_weights, topk_indic
i
es, \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, \
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);
stream);
...
@@ -433,7 +433,7 @@ template <typename IndType>
...
@@ -433,7 +433,7 @@ template <typename IndType>
void
topkGatingSoftmaxKernelLauncher
(
void
topkGatingSoftmaxKernelLauncher
(
const
float
*
gating_output
,
const
float
*
gating_output
,
float
*
topk_weights
,
float
*
topk_weights
,
IndType
*
topk_indic
i
es
,
IndType
*
topk_indices
,
int
*
token_expert_indices
,
int
*
token_expert_indices
,
float
*
softmax_workspace
,
float
*
softmax_workspace
,
const
int
num_tokens
,
const
int
num_tokens
,
...
@@ -476,7 +476,7 @@ void topkGatingSoftmaxKernelLauncher(
...
@@ -476,7 +476,7 @@ void topkGatingSoftmaxKernelLauncher(
moeSoftmax
<
TPB
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
moeSoftmax
<
TPB
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
gating_output
,
nullptr
,
softmax_workspace
,
num_experts
);
gating_output
,
nullptr
,
softmax_workspace
,
num_experts
);
moeTopK
<
TPB
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
moeTopK
<
TPB
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
softmax_workspace
,
nullptr
,
topk_weights
,
topk_indic
i
es
,
token_expert_indices
,
softmax_workspace
,
nullptr
,
topk_weights
,
topk_indices
,
token_expert_indices
,
num_experts
,
topk
,
0
,
num_experts
);
num_experts
,
topk
,
0
,
num_experts
);
}
}
}
}
...
...
csrc/moe/torch_bindings.cpp
View file @
2f1c19b2
...
@@ -66,7 +66,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -66,7 +66,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m
.
def
(
m
.
def
(
"moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids,"
"moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids,"
"Tensor token_expert_indic
i
es, Tensor? expert_map, int n_expert,"
"Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
"int n_local_expert,"
"int n_local_expert,"
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! "
"expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! "
...
...
csrc/quantization/machete/machete_mainloop.cuh
View file @
2f1c19b2
...
@@ -1003,7 +1003,7 @@ struct MacheteCollectiveMma {
...
@@ -1003,7 +1003,7 @@ struct MacheteCollectiveMma {
static
constexpr
int
A_CPY_VEC
=
static
constexpr
int
A_CPY_VEC
=
decltype
(
max_common_vector
(
tCsA
,
tCrA_load
)){};
decltype
(
max_common_vector
(
tCsA
,
tCrA_load
)){};
static
constexpr
int
COVERSION_WIDTH
=
static
constexpr
int
CO
N
VERSION_WIDTH
=
std
::
min
(
A_CPY_VEC
,
int
(
size
<
0
>
(
tCrA_mma
)));
std
::
min
(
A_CPY_VEC
,
int
(
size
<
0
>
(
tCrA_mma
)));
auto
load_A_to_registers
=
[
&
](
int
read_stage
)
{
auto
load_A_to_registers
=
[
&
](
int
read_stage
)
{
...
@@ -1026,8 +1026,8 @@ struct MacheteCollectiveMma {
...
@@ -1026,8 +1026,8 @@ struct MacheteCollectiveMma {
// PIPELINED MAIN LOOP
// PIPELINED MAIN LOOP
//
//
auto
convert_A
=
[
&
,
a_vec
=
Int
<
COVERSION_WIDTH
>
{}](
int
k_block
,
auto
convert_A
=
[
&
,
a_vec
=
Int
<
CO
N
VERSION_WIDTH
>
{}](
int
k_block
,
int
read_stage
)
{
int
read_stage
)
{
load_extra_info_to_registers
(
partitioned_extra_info
,
load_extra_info_to_registers
(
partitioned_extra_info
,
copy_partitions_extra_info
,
k_block
,
copy_partitions_extra_info
,
k_block
,
read_stage
);
read_stage
);
...
...
csrc/rocm/skinny_gemms.cu
View file @
2f1c19b2
...
@@ -320,7 +320,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -320,7 +320,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Goal is to bring the activation matrix A to the LDS
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// TODO: When activation matrix is larger than 64 KB
// then this is not goin
t
to work!
// then this is not goin
g
to work!
//----------------------------------------------------
//----------------------------------------------------
__shared__
scalar_t
s
[
max_lds_len
];
__shared__
scalar_t
s
[
max_lds_len
];
...
@@ -581,7 +581,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -581,7 +581,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Goal is to bring the activation matrix A to the LDS
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// TODO: When activation matrix is larger than 64 KB
// then this is not goin
t
to work!
// then this is not goin
g
to work!
//----------------------------------------------------
//----------------------------------------------------
__shared__
scalar_t
s
[
max_lds_len
];
__shared__
scalar_t
s
[
max_lds_len
];
...
@@ -601,7 +601,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -601,7 +601,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp);
// int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp);
uint32_t
m
=
(
blockIdx
.
x
*
_WvPrGrp
+
threadIdx
.
y
)
*
YTILE
;
uint32_t
m
=
(
blockIdx
.
x
*
_WvPrGrp
+
threadIdx
.
y
)
*
YTILE
;
// Check whether there will be fragmenation!
// Check whether there will be fragmen
t
ation!
// This will happen only for the last wave!
// This will happen only for the last wave!
if
(
m
<
M
&&
(
m
+
YTILE
)
>=
M
)
{
if
(
m
<
M
&&
(
m
+
YTILE
)
>=
M
)
{
uint32_t
startColumn
=
M
-
YTILE
;
uint32_t
startColumn
=
M
-
YTILE
;
...
@@ -827,7 +827,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -827,7 +827,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
// Check whether there will be fragmenation!
// Check whether there will be fragmen
t
ation!
// This will happen only for the last wave!
// This will happen only for the last wave!
if
(
m
<
M
&&
(
m
+
YTILE
)
>=
M
)
{
if
(
m
<
M
&&
(
m
+
YTILE
)
>=
M
)
{
uint32_t
startColumn
=
M
-
YTILE
;
uint32_t
startColumn
=
M
-
YTILE
;
...
@@ -882,7 +882,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -882,7 +882,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Goal is to bring the activation matrix A to the LDS
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// TODO: When activation matrix is larger than 64 KB
// then this is not goin
t
to work!
// then this is not goin
g
to work!
//----------------------------------------------------
//----------------------------------------------------
__shared__
scalar_t
s
[
max_lds_len
];
__shared__
scalar_t
s
[
max_lds_len
];
...
@@ -904,7 +904,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -904,7 +904,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
//----------------------------------------------------
uint32_t
m
=
(
blockIdx
.
x
*
_WvPrGrp
+
threadIdx
.
y
)
*
YTILE
;
uint32_t
m
=
(
blockIdx
.
x
*
_WvPrGrp
+
threadIdx
.
y
)
*
YTILE
;
// Check whether there will be fragmenation!
// Check whether there will be fragmen
t
ation!
// This will happen only for the last wave!
// This will happen only for the last wave!
if
(
m
<
M
&&
(
m
+
YTILE
)
>=
M
)
{
if
(
m
<
M
&&
(
m
+
YTILE
)
>=
M
)
{
uint32_t
startColumn
=
M
-
YTILE
;
uint32_t
startColumn
=
M
-
YTILE
;
...
@@ -1176,7 +1176,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1176,7 +1176,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
kBase
=
0
;
kBase
=
0
;
// Check whether there will be fragmenation!
// Check whether there will be fragmen
t
ation!
// This will happen only for the last wave!
// This will happen only for the last wave!
if
(
m
<
M
&&
(
m
+
YTILE
)
>=
M
)
{
if
(
m
<
M
&&
(
m
+
YTILE
)
>=
M
)
{
uint32_t
startColumn
=
M
-
YTILE
;
uint32_t
startColumn
=
M
-
YTILE
;
...
...
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
View file @
2f1c19b2
...
@@ -277,7 +277,7 @@ CompressorResult cutlass_sparse_compress_sm90(torch::Tensor const& a) {
...
@@ -277,7 +277,7 @@ CompressorResult cutlass_sparse_compress_sm90(torch::Tensor const& a) {
uint32_t
const
m
=
1
;
// Set M to 1 for compression
uint32_t
const
m
=
1
;
// Set M to 1 for compression
uint32_t
const
n
=
a
.
size
(
1
);
uint32_t
const
n
=
a
.
size
(
1
);
// Note: For correctess, the compressed format must be invariant in:
// Note: For correct
n
ess, the compressed format must be invariant in:
// - M, the flattened number of tokens
// - M, the flattened number of tokens
// - Whether output dtype is fp16 or bf16
// - Whether output dtype is fp16 or bf16
// - CUTLASS epilogues
// - CUTLASS epilogues
...
...
pyproject.toml
View file @
2f1c19b2
...
@@ -137,10 +137,6 @@ exclude = [
...
@@ -137,10 +137,6 @@ exclude = [
'vllm/attention/ops/.*\.py$'
'vllm/attention/ops/.*\.py$'
]
]
[tool.codespell]
ignore-words-list
=
"dout, te, indicies, subtile, ElementE"
skip
=
"tests/models/fixtures/*,tests/prompts/*,benchmarks/sonnet.txt,tests/lora/data/*,build/*,vllm/third_party/*"
[tool.isort]
[tool.isort]
skip_glob
=
[
skip_glob
=
[
".buildkite/*"
,
".buildkite/*"
,
...
...
tests/compile/test_async_tp.py
View file @
2f1c19b2
...
@@ -223,7 +223,7 @@ def test_async_tp_pass_correctness(
...
@@ -223,7 +223,7 @@ def test_async_tp_pass_correctness(
"VLLM_USE_V1"
:
"1"
,
"VLLM_USE_V1"
:
"1"
,
}
}
a
y
snc_tp_args
=
[
as
y
nc_tp_args
=
[
*
common_args
,
*
common_args
,
"--tensor-parallel-size"
,
"--tensor-parallel-size"
,
str
(
tp_size
),
str
(
tp_size
),
...
@@ -242,7 +242,7 @@ def test_async_tp_pass_correctness(
...
@@ -242,7 +242,7 @@ def test_async_tp_pass_correctness(
]
]
compare_two_settings
(
model_id
,
compare_two_settings
(
model_id
,
a
y
snc_tp_args
,
as
y
nc_tp_args
,
tp_args
,
tp_args
,
async_tp_env
,
async_tp_env
,
tp_env
,
tp_env
,
...
...
tests/core/block/e2e/test_correctness.py
View file @
2f1c19b2
...
@@ -437,8 +437,8 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
...
@@ -437,8 +437,8 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
"enable_prefix_caching"
:
True
,
"enable_prefix_caching"
:
True
,
}])
}])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_auto_prefix_caching_after_evition_start
(
baseline_llm_generator
,
def
test_auto_prefix_caching_after_evi
c
tion_start
(
baseline_llm_generator
,
test_llm_generator
):
test_llm_generator
):
"""Verify block manager v2 with auto prefix caching could works normal
"""Verify block manager v2 with auto prefix caching could works normal
even when eviction started.
even when eviction started.
With APC enabled, all blocks are held by native block at the beginning.
With APC enabled, all blocks are held by native block at the beginning.
...
...
tests/core/block/e2e/test_correctness_sliding_window.py
View file @
2f1c19b2
...
@@ -33,8 +33,8 @@ BLOCK_SIZE = 16
...
@@ -33,8 +33,8 @@ BLOCK_SIZE = 16
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"XFORMERS"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"XFORMERS"
])
def
test_sliding_window_retrival
(
baseline_llm_generator
,
test_llm_generator
,
def
test_sliding_window_retri
e
val
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
seed
,
backend
,
monkeypatch
):
batch_size
,
seed
,
backend
,
monkeypatch
):
"""
"""
The test does a bunch of assignments "x1 = 10
\n
x2 = 33
\n
..." and then
The test does a bunch of assignments "x1 = 10
\n
x2 = 33
\n
..." and then
asks for value of one of them (which is outside the sliding window).
asks for value of one of them (which is outside the sliding window).
...
@@ -100,7 +100,7 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
...
@@ -100,7 +100,7 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
def
test_sliding_window_chunked_prefill
(
test_llm_generator
,
batch_size
,
seed
,
def
test_sliding_window_chunked_prefill
(
test_llm_generator
,
batch_size
,
seed
,
backend
,
monkeypatch
):
backend
,
monkeypatch
):
"""
"""
This is similar to test_sliding_window_retrival, however, it doesn't
This is similar to test_sliding_window_retri
e
val, however, it doesn't
compare against the v1 block manager since v1 doesn't support
compare against the v1 block manager since v1 doesn't support
chunked prefill with sliding window.
chunked prefill with sliding window.
...
...
tests/core/test_scheduler.py
View file @
2f1c19b2
...
@@ -594,8 +594,8 @@ def test_decode_schedule_preempted():
...
@@ -594,8 +594,8 @@ def test_decode_schedule_preempted():
# should be preempted. 1 will also be preempted.
# should be preempted. 1 will also be preempted.
budget
=
create_token_budget
()
budget
=
create_token_budget
()
output
=
scheduler
.
_schedule_running
(
budget
,
curr_loras
)
output
=
scheduler
.
_schedule_running
(
budget
,
curr_loras
)
remainig_running
=
scheduler
.
running
remaini
n
g_running
=
scheduler
.
running
assert
len
(
remainig_running
)
==
0
assert
len
(
remaini
n
g_running
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
output
.
decode_seq_groups
[
0
].
seq_group
.
request_id
==
"0"
assert
output
.
decode_seq_groups
[
0
].
seq_group
.
request_id
==
"0"
...
...
tests/entrypoints/openai/test_chat_template.py
View file @
2f1c19b2
...
@@ -16,7 +16,7 @@ chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
...
@@ -16,7 +16,7 @@ chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert
chatml_jinja_path
.
exists
()
assert
chatml_jinja_path
.
exists
()
# Define models, templates, and their corresponding expected outputs
# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATON_OUTPUT
=
[
MODEL_TEMPLATE_GENERAT
I
ON_OUTPUT
=
[
(
"facebook/opt-125m"
,
chatml_jinja_path
,
True
,
False
,
"""<|im_start|>user
(
"facebook/opt-125m"
,
chatml_jinja_path
,
True
,
False
,
"""<|im_start|>user
Hello<|im_end|>
Hello<|im_end|>
<|im_start|>assistant
<|im_start|>assistant
...
@@ -91,7 +91,7 @@ def test_no_load_chat_template_literallike():
...
@@ -91,7 +91,7 @@ def test_no_load_chat_template_literallike():
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model,template,add_generation_prompt,continue_final_message,expected_output"
,
"model,template,add_generation_prompt,continue_final_message,expected_output"
,
MODEL_TEMPLATE_GENERATON_OUTPUT
)
MODEL_TEMPLATE_GENERAT
I
ON_OUTPUT
)
def
test_get_gen_prompt
(
model
,
template
,
add_generation_prompt
,
def
test_get_gen_prompt
(
model
,
template
,
add_generation_prompt
,
continue_final_message
,
expected_output
):
continue_final_message
,
expected_output
):
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model
)
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model
)
...
...
tests/kernels/attention/test_cache.py
View file @
2f1c19b2
...
@@ -72,8 +72,8 @@ def test_copy_blocks(
...
@@ -72,8 +72,8 @@ def test_copy_blocks(
# destination blocks.
# destination blocks.
assert
2
*
num_mappings
<=
num_blocks
assert
2
*
num_mappings
<=
num_blocks
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
remainig_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
remaini
n
g_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
dst_blocks
=
random
.
sample
(
remainig_blocks
,
2
*
num_mappings
)
dst_blocks
=
random
.
sample
(
remaini
n
g_blocks
,
2
*
num_mappings
)
block_mapping
:
list
[
tuple
[
int
,
int
]]
=
[]
block_mapping
:
list
[
tuple
[
int
,
int
]]
=
[]
for
i
in
range
(
num_mappings
):
for
i
in
range
(
num_mappings
):
src
=
src_blocks
[
i
]
src
=
src_blocks
[
i
]
...
@@ -189,12 +189,12 @@ def test_reshape_and_cache(
...
@@ -189,12 +189,12 @@ def test_reshape_and_cache(
# Run the reference implementation.
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
block_indic
i
es
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indices
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indic
i
es_lst
=
block_indic
i
es
.
cpu
().
tolist
()
block_indices_lst
=
block_indices
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
slot_mapping
%
block_size
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
for
i
in
range
(
num_tokens
):
block_idx
=
block_indic
i
es_lst
[
i
]
block_idx
=
block_indices_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
...
@@ -322,12 +322,12 @@ def test_reshape_and_cache_flash(
...
@@ -322,12 +322,12 @@ def test_reshape_and_cache_flash(
kv_dtype
=
kv_cache_dtype
)
kv_dtype
=
kv_cache_dtype
)
# Run the reference implementation.
# Run the reference implementation.
block_indic
i
es
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indices
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indic
i
es_lst
=
block_indic
i
es
.
cpu
().
tolist
()
block_indices_lst
=
block_indices
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
slot_mapping
%
block_size
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
for
i
in
range
(
num_tokens
):
block_idx
=
block_indic
i
es_lst
[
i
]
block_idx
=
block_indices_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
if
kv_cache_layout
==
"NHD"
:
if
kv_cache_layout
==
"NHD"
:
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
...
...
tests/kernels/attention/test_encoder_decoder_attn.py
View file @
2f1c19b2
...
@@ -46,7 +46,7 @@ CUDA_DEVICE = "cuda:0"
...
@@ -46,7 +46,7 @@ CUDA_DEVICE = "cuda:0"
MAX_DEC_SEQ_LENS
=
[
128
]
MAX_DEC_SEQ_LENS
=
[
128
]
MAX_ENC_SEQ_LENS
=
[
128
]
MAX_ENC_SEQ_LENS
=
[
128
]
# Narrow te
e
st-cases for unsupported-scenario
# Narrow test-cases for unsupported-scenario
# tests
# tests
HEAD_SIZES_FOR_UNSUPP
=
[
HEAD_SIZES
[
0
]]
HEAD_SIZES_FOR_UNSUPP
=
[
HEAD_SIZES
[
0
]]
...
...
tests/kernels/core/test_rotary_embedding.py
View file @
2f1c19b2
...
@@ -39,10 +39,10 @@ def rotary_embedding_opcheck(rot,
...
@@ -39,10 +39,10 @@ def rotary_embedding_opcheck(rot,
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
32
,
108
])
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
32
,
108
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
11
,
1024
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
11
,
1024
])
@
pytest
.
mark
.
parametrize
(
"use_key"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_key"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"head_stride_is_conti
n
gous"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"head_stride_is_contig
u
ous"
,
[
True
,
False
])
def
test_rotary_embedding_opcheck
(
dist_init
,
device
,
max_position
,
def
test_rotary_embedding_opcheck
(
dist_init
,
device
,
max_position
,
is_neox_style
,
rotary_dim
,
head_size
,
is_neox_style
,
rotary_dim
,
head_size
,
seq_len
,
use_key
,
head_stride_is_conti
n
gous
):
seq_len
,
use_key
,
head_stride_is_contig
u
ous
):
batch_size
=
1
batch_size
=
1
base
=
10000
base
=
10000
num_heads
=
7
num_heads
=
7
...
@@ -52,7 +52,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
...
@@ -52,7 +52,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
positions
=
torch
.
randint
(
0
,
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
),
max_position
,
(
batch_size
,
seq_len
),
device
=
device
)
device
=
device
)
head_stride
=
head_size
+
(
64
if
head_stride_is_conti
n
gous
else
0
)
head_stride
=
head_size
+
(
64
if
head_stride_is_contig
u
ous
else
0
)
query
=
torch
.
randn
(
batch_size
,
query
=
torch
.
randn
(
batch_size
,
seq_len
,
seq_len
,
...
@@ -72,7 +72,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
...
@@ -72,7 +72,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
# if we have a contiguous head stride, test the alternate
# if we have a contiguous head stride, test the alternate
# [..., num_heads * head_dim] shape/layout
# [..., num_heads * head_dim] shape/layout
if
head_stride_is_conti
n
gous
:
if
head_stride_is_contig
u
ous
:
rotary_embedding_opcheck
(
rotary_embedding_opcheck
(
rot
,
positions
,
query
.
flatten
(
start_dim
=-
2
),
rot
,
positions
,
query
.
flatten
(
start_dim
=-
2
),
key
.
flatten
(
start_dim
=-
2
)
if
use_key
else
None
)
key
.
flatten
(
start_dim
=-
2
)
if
use_key
else
None
)
tests/kernels/mamba/test_mamba_ssm_ssd.py
View file @
2f1c19b2
...
@@ -107,15 +107,15 @@ def generate_random_inputs(batch_size,
...
@@ -107,15 +107,15 @@ def generate_random_inputs(batch_size,
return
A
,
dt
,
X
,
B
,
C
return
A
,
dt
,
X
,
B
,
C
def
generate_continous_batched_examples
(
example_lens_by_batch
,
def
generate_contin
u
ous_batched_examples
(
example_lens_by_batch
,
num_examples
,
num_examples
,
full_length
,
full_length
,
last_taken
,
last_taken
,
exhausted
,
exhausted
,
n_heads
,
n_heads
,
d_head
,
d_head
,
itype
,
itype
,
device
=
'cuda'
):
device
=
'cuda'
):
# this function generates a random examples of certain length
# this function generates a random examples of certain length
# and then cut according to "example_lens_by_batch" and feed
# and then cut according to "example_lens_by_batch" and feed
...
@@ -269,11 +269,10 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
...
@@ -269,11 +269,10 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
exhausted
:
dict
=
{}
# map: eg -> boolean indicating example is exhausted
exhausted
:
dict
=
{}
# map: eg -> boolean indicating example is exhausted
states
=
None
states
=
None
for
Y_min
,
cu_seqlens
,
seq_idx
,
(
A
,
dt
,
X
,
B
,
for
Y_min
,
cu_seqlens
,
seq_idx
,
(
C
)
in
generate_continous_batched_examples
(
A
,
dt
,
X
,
B
,
C
)
in
generate_continuous_batched_examples
(
cases
,
num_examples
,
seqlen
,
cases
,
num_examples
,
seqlen
,
last_taken
,
exhausted
,
n_heads
,
last_taken
,
exhausted
,
n_heads
,
d_head
,
itype
):
d_head
,
itype
):
chunk_indices
,
chunk_offsets
=
\
chunk_indices
,
chunk_offsets
=
\
_query_start_loc_to_chunk_indices_offsets
(
_query_start_loc_to_chunk_indices_offsets
(
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment