Commit 8a325c18 authored by zhuwenwen's avatar zhuwenwen
Browse files

update VLLM_USE_OPT_OP to use opt kernels

parent ef037256
......@@ -128,7 +128,7 @@ class PagedAttention:
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.USE_VLLM_OPT_OP:
if envs.VLLM_USE_OPT_OP:
ops.paged_attention_v1_opt(
output,
query,
......@@ -191,7 +191,7 @@ class PagedAttention:
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.USE_VLLM_OPT_OP:
if envs.VLLM_USE_OPT_OP:
ops.paged_attention_v2_opt(
output,
exp_sums,
......
......@@ -10,7 +10,7 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False
USE_VLLM_OPT_OP: bool = False
VLLM_USE_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
......@@ -142,8 +142,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
("true", "1")),
# flag to control vllm to use optimized kernels
"USE_VLLM_OPT_OP":
lambda: (os.environ.get("USE_VLLM_OPT_OP", "True").lower() in
"VLLM_USE_OPT_OP":
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
("true", "1")),
# flag to control if vllm print pa parameters
......
......@@ -35,7 +35,7 @@ class SiluAndMul(CustomOp):
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if envs.USE_VLLM_OPT_OP:
if envs.VLLM_USE_OPT_OP:
ops.silu_and_mul(out, x)
else:
ops.silu_and_mul(out, x)
......@@ -70,12 +70,12 @@ class GeluAndMul(CustomOp):
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if self.approximate == "none":
if envs.USE_VLLM_OPT_OP:
if envs.VLLM_USE_OPT_OP:
ops.gelu_and_mul_opt(out, x)
else:
ops.gelu_and_mul(out, x)
elif self.approximate == "tanh":
if envs.USE_VLLM_OPT_OP:
if envs.VLLM_USE_OPT_OP:
ops.gelu_tanh_and_mul_opt(out, x)
else:
ops.gelu_tanh_and_mul(out, x)
......
......@@ -52,7 +52,7 @@ class RMSNorm(CustomOp):
from vllm import _custom_ops as ops
if residual is not None:
if envs.USE_VLLM_OPT_OP:
if envs.VLLM_USE_OPT_OP:
ops.fused_add_rms_norm_opt(
x,
residual,
......@@ -68,7 +68,7 @@ class RMSNorm(CustomOp):
)
return x, residual
out = torch.empty_like(x)
if envs.USE_VLLM_OPT_OP:
if envs.VLLM_USE_OPT_OP:
ops.rms_norm_opt(
out,
x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment