Commit 29a9e952 authored by zhuwenwen's avatar zhuwenwen
Browse files

remove pa

parent 941c2260
...@@ -260,11 +260,6 @@ set(VLLM_EXT_SRC ...@@ -260,11 +260,6 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
"csrc/opt/transpose_kernels.cu" "csrc/opt/transpose_kernels.cu"
"csrc/opt/activation_kernels_opt.cu" "csrc/opt/activation_kernels_opt.cu"
"csrc/attention/attention_kernels_opt.cu"
"csrc/attention/attention_kernels_opt_tc.cu"
"csrc/attention/attention_with_mask_kernels.cu"
"csrc/attention/attention_with_mask_kernels_opt.cu"
"csrc/attention/attention_with_mask_kernels_opt_tc.cu"
"csrc/opt/layernorm_kernels_opt.cu" "csrc/opt/layernorm_kernels_opt.cu"
# "csrc/layernorm_quant_kernels.cu" # "csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu" "csrc/sampler.cu"
......
...@@ -119,42 +119,6 @@ def main( ...@@ -119,42 +119,6 @@ def main(
for _ in range(num_iters): for _ in range(num_iters):
if version == "v1": if version == "v1":
if args.gc_paged_attn:
if args.tc_paged_attn:
ops.paged_attention_v1_opt_tc(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.paged_attention_v1_opt(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.paged_attention_v1( ops.paged_attention_v1(
output, output,
query, query,
...@@ -173,44 +137,6 @@ def main( ...@@ -173,44 +137,6 @@ def main(
) )
elif version == "v2": elif version == "v2":
if not args.custom_paged_attn: if not args.custom_paged_attn:
if args.gc_paged_attn:
if args.tc_paged_attn:
ops.paged_attention_v1_opt_tc(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.paged_attention_v2_opt(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
ops.paged_attention_v2( ops.paged_attention_v2(
output, output,
exp_sums, exp_sums,
...@@ -251,24 +177,6 @@ def main( ...@@ -251,24 +177,6 @@ def main(
k_scale, k_scale,
v_scale, v_scale,
) )
elif version == "v12":
from flash_attn import vllm_flash_attn_with_kvcache
vllm_flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
cache_seqlens=seq_lens,
block_table=block_tables,
softmax_scale=scale,
causal=True,
window_size=sliding_window,
softcap=logits_soft_cap,
alibi_slopes=alibi_slopes,
return_softmax_lse=False,
k_scale=k_scale,
v_scale=v_scale,
kv_cache_dtype=kv_cache_dtype,
).squeeze(1)
else: else:
raise ValueError(f"Invalid version: {version}") raise ValueError(f"Invalid version: {version}")
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -298,7 +206,7 @@ if __name__ == "__main__": ...@@ -298,7 +206,7 @@ if __name__ == "__main__":
) )
parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.") parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
parser.add_argument("--version", type=str, choices=["v1", "v2", "v12"], default="v12") parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--seq-len", type=int, default=4096) parser.add_argument("--seq-len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-query-heads", type=int, default=64)
...@@ -324,12 +232,6 @@ if __name__ == "__main__": ...@@ -324,12 +232,6 @@ if __name__ == "__main__":
help="Data type for kv cache storage. If 'auto', will use model " help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (hcu) supports fp8 (=fp8_e4m3)") "ROCm (hcu) supports fp8 (=fp8_e4m3)")
parser.add_argument(
"--gc-paged-attn", action="store_true", help="Use gc paged attention"
)
parser.add_argument(
"--tc-paged-attn", action="store_true", help="Use tc paged attention"
)
parser.add_argument( parser.add_argument(
"--custom-paged-attn", action="store_true", help="Use custom paged attention" "--custom-paged-attn", action="store_true", help="Use custom paged attention"
) )
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define OPT_SWITCH(COND, ...) \
[&] { \
if (COND) { \
constexpr static int opt = 1; \
return __VA_ARGS__(); \
} else { \
constexpr static int opt = 2; \
return __VA_ARGS__(); \
} \
}()
#define NUM_THREADS_SWITCH(NUM_THREAD, ...) \
[&] { \
if (NUM_THREAD == 256) { \
constexpr static int NUM_THREADS = 256; \
return __VA_ARGS__(); \
} else { \
constexpr static int NUM_THREADS = 128; \
return __VA_ARGS__(); \
} \
}()
#define HEADSIZE_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM == 32) { \
constexpr static int HEAD_SIZE = 32; \
return __VA_ARGS__(); \
} else if (HEADDIM == 64) { \
constexpr static int HEAD_SIZE = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM == 80) { \
constexpr static int HEAD_SIZE = 80; \
return __VA_ARGS__(); \
} else if (HEADDIM == 96) { \
constexpr static int HEAD_SIZE = 96; \
return __VA_ARGS__(); \
} else if (HEADDIM == 112) { \
constexpr static int HEAD_SIZE = 112; \
return __VA_ARGS__(); \
} else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM == 192) { \
constexpr static int HEAD_SIZE = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM == 256) { \
constexpr static int HEAD_SIZE = 256; \
return __VA_ARGS__(); \
} \
else { \
TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\
} \
}()
#define REUSEKV_SWITCH(num_blocks , ...) \
[&] { \
if (num_heads % 2 == 0 && num_heads / num_kv_heads >= 4 && num_blocks >= 1200){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else if (num_heads / num_kv_heads >= 2 && num_blocks >= 1200){\
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V1(num_blocks , ...) \
[&] { \
if (num_heads > num_kv_heads && num_blocks >= 1200){ \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define NUM_THREADS_SWITCH(NUM_THREAD, ...) \
[&] { \
if (NUM_THREAD == 256) { \
constexpr static int NUM_THREADS = 256; \
return __VA_ARGS__(); \
}else if (NUM_THREAD == 128) { \
constexpr static int NUM_THREADS = 128; \
return __VA_ARGS__(); \
} else { \
constexpr static int NUM_THREADS = 64; \
return __VA_ARGS__(); \
} \
}()
#define HEADSIZE_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM == 32) { \
constexpr static int HEAD_SIZE = 32; \
return __VA_ARGS__(); \
} else if (HEADDIM == 64) { \
constexpr static int HEAD_SIZE = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM == 80) { \
constexpr static int HEAD_SIZE = 80; \
return __VA_ARGS__(); \
} else if (HEADDIM == 96) { \
constexpr static int HEAD_SIZE = 96; \
return __VA_ARGS__(); \
} else if (HEADDIM == 112) { \
constexpr static int HEAD_SIZE = 112; \
return __VA_ARGS__(); \
} else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM == 192) { \
constexpr static int HEAD_SIZE = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM == 256) { \
constexpr static int HEAD_SIZE = 256; \
return __VA_ARGS__(); \
} \
else { \
TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\
} \
}()
#define REUSEKV_SWITCH(reusekv,...) \
[&] { \
if (reusekv==16){ \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__();} \
else if (reusekv==8){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (reusekv==4){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
}else if (reusekv==2){ \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
}else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define USEVMAC_SWITCH_V1(num_blocks , ...) \
[&] { \
if (REUSE_KV_TIMES==1&&(num_blocks >2500 || padded_max_seq_len > 2048)){ \
constexpr static int use_vmac = false; \
return __VA_ARGS__(); \
} else { \
constexpr static int use_vmac = true; \
return __VA_ARGS__(); \
} \
}()
\ No newline at end of file
...@@ -52,126 +52,6 @@ void paged_attention_v2( ...@@ -52,126 +52,6 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt_tc(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt_tc(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
// paged_attention with attn_masks
void paged_attention_v1_with_mask(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v2_with_mask(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v1_opt_with_mask(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v2_opt_with_mask(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v1_opt_tc_with_mask(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v2_opt_tc_with_mask(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void merge_attn_states(torch::Tensor& output, void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_output,
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -33,7 +33,6 @@ from vllm.transformers_utils.runai_utils import (ObjectStorageModel, ...@@ -33,7 +33,6 @@ from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
is_runai_obj_uri) is_runai_obj_uri)
from vllm.transformers_utils.utils import maybe_model_redirect from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import LayerBlockType, LazyLoader, common_broadcastable_dtype from vllm.utils import LayerBlockType, LazyLoader, common_broadcastable_dtype
from vllm.utils import SUPPORT_TC
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -1451,7 +1450,7 @@ class ModelConfig: ...@@ -1451,7 +1450,7 @@ class ModelConfig:
@property @property
def use_mla(self) -> bool: def use_mla(self) -> bool:
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE and SUPPORT_TC return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
@property @property
def is_matryoshka(self) -> bool: def is_matryoshka(self) -> bool:
......
...@@ -212,7 +212,6 @@ if TYPE_CHECKING: ...@@ -212,7 +212,6 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_FLASH_MLA: bool = False VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
VLLM_TREE_DECODING: bool = False VLLM_TREE_DECODING: bool = False
VLLM_SPEC_DECODE_EAGER: bool = False VLLM_SPEC_DECODE_EAGER: bool = False
...@@ -1541,11 +1540,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1541,11 +1540,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
("true", "1")), ("true", "1")),
# flag to control vllm to use optimized tc paged attn kernels
"VLLM_USE_TC_PAGED_ATTN":
lambda: (os.environ.get("VLLM_USE_TC_PAGED_ATTN", "True").lower() in
("true", "1")),
# flag to control if vllm print pa parameters # flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM": "VLLM_USE_PA_PRINT_PARAM":
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
......
...@@ -24,7 +24,6 @@ from vllm.model_executor.parameter import BasevLLMParameter ...@@ -24,7 +24,6 @@ from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import SUPPORT_TC
DEFAULT_VOCAB_PADDING_SIZE = 64 DEFAULT_VOCAB_PADDING_SIZE = 64
...@@ -174,9 +173,8 @@ class VocabParallelEmbeddingShardIndices: ...@@ -174,9 +173,8 @@ class VocabParallelEmbeddingShardIndices:
assert self.num_added_elements <= self.num_added_elements_padded assert self.num_added_elements <= self.num_added_elements_padded
if SUPPORT_TC: @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def get_masked_input_and_mask(
def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int, input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int, org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int, added_vocab_start_index: int,
...@@ -194,26 +192,6 @@ if SUPPORT_TC: ...@@ -194,26 +192,6 @@ if SUPPORT_TC:
vocab_mask = org_vocab_mask | added_vocab_mask vocab_mask = org_vocab_mask | added_vocab_mask
input_ = vocab_mask * (input_ - valid_offset) input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask return input_, ~vocab_mask
else:
def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> tuple[torch.Tensor, torch.Tensor]:
# torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (
input_ < org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
valid_offset = (org_vocab_start_index *
org_vocab_mask) + (added_offset * added_vocab_mask)
vocab_mask = org_vocab_mask | added_vocab_mask
input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask
@CustomOp.register("vocab_parallel_embedding") @CustomOp.register("vocab_parallel_embedding")
......
...@@ -16,13 +16,6 @@ from vllm.utils import cuda_device_count_stateless ...@@ -16,13 +16,6 @@ from vllm.utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
from vllm.utils import SUPPORT_TC
if not SUPPORT_TC:
os.environ['VLLM_USE_V1'] = '0'
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
os.environ['VLLM_USE_FLASH_MLA'] = '0'
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
......
...@@ -92,9 +92,6 @@ DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 ...@@ -92,9 +92,6 @@ DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
SUPPORT_TC = any(arch in GPU_ARCH for arch in ["gfx928", "gfx936", "gfx938"])
# Constants related to forcing the attention backend selection # Constants related to forcing the attention backend selection
# String name of register which may be set in order to # String name of register which may be set in order to
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment