"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "62920e378d2c7fad554b3ac479e0dd8f8e87d409"
Commit 8a325c18 authored by zhuwenwen
Browse files

update VLLM_USE_OPT_OP to use opt kernels

parent ef037256
...@@ -128,7 +128,7 @@ class PagedAttention: ...@@ -128,7 +128,7 @@ class PagedAttention:
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}") print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}") print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.USE_VLLM_OPT_OP: if envs.VLLM_USE_OPT_OP:
ops.paged_attention_v1_opt( ops.paged_attention_v1_opt(
output, output,
query, query,
...@@ -191,7 +191,7 @@ class PagedAttention: ...@@ -191,7 +191,7 @@ class PagedAttention:
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}") print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}") print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.USE_VLLM_OPT_OP: if envs.VLLM_USE_OPT_OP:
ops.paged_attention_v2_opt( ops.paged_attention_v2_opt(
output, output,
exp_sums, exp_sums,
......
...@@ -10,7 +10,7 @@ if TYPE_CHECKING: ...@@ -10,7 +10,7 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False VLLM_USE_FLASH_ATTN_AUTO: bool = False
USE_VLLM_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
LOCAL_RANK: int = 0 LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None CUDA_VISIBLE_DEVICES: Optional[str] = None
...@@ -142,8 +142,8 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -142,8 +142,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
("true", "1")), ("true", "1")),
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
"USE_VLLM_OPT_OP": "VLLM_USE_OPT_OP":
lambda: (os.environ.get("USE_VLLM_OPT_OP", "True").lower() in lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
("true", "1")), ("true", "1")),
# flag to control if vllm print pa parameters # flag to control if vllm print pa parameters
......
...@@ -35,7 +35,7 @@ class SiluAndMul(CustomOp): ...@@ -35,7 +35,7 @@ class SiluAndMul(CustomOp):
d = x.shape[-1] // 2 d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, )) output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device) out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if envs.USE_VLLM_OPT_OP: if envs.VLLM_USE_OPT_OP:
ops.silu_and_mul(out, x) ops.silu_and_mul(out, x)
else: else:
ops.silu_and_mul(out, x) ops.silu_and_mul(out, x)
...@@ -70,12 +70,12 @@ class GeluAndMul(CustomOp): ...@@ -70,12 +70,12 @@ class GeluAndMul(CustomOp):
output_shape = (x.shape[:-1] + (d, )) output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device) out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if self.approximate == "none": if self.approximate == "none":
if envs.USE_VLLM_OPT_OP: if envs.VLLM_USE_OPT_OP:
ops.gelu_and_mul_opt(out, x) ops.gelu_and_mul_opt(out, x)
else: else:
ops.gelu_and_mul(out, x) ops.gelu_and_mul(out, x)
elif self.approximate == "tanh": elif self.approximate == "tanh":
if envs.USE_VLLM_OPT_OP: if envs.VLLM_USE_OPT_OP:
ops.gelu_tanh_and_mul_opt(out, x) ops.gelu_tanh_and_mul_opt(out, x)
else: else:
ops.gelu_tanh_and_mul(out, x) ops.gelu_tanh_and_mul(out, x)
......
...@@ -52,7 +52,7 @@ class RMSNorm(CustomOp): ...@@ -52,7 +52,7 @@ class RMSNorm(CustomOp):
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
if residual is not None: if residual is not None:
if envs.USE_VLLM_OPT_OP: if envs.VLLM_USE_OPT_OP:
ops.fused_add_rms_norm_opt( ops.fused_add_rms_norm_opt(
x, x,
residual, residual,
...@@ -68,7 +68,7 @@ class RMSNorm(CustomOp): ...@@ -68,7 +68,7 @@ class RMSNorm(CustomOp):
) )
return x, residual return x, residual
out = torch.empty_like(x) out = torch.empty_like(x)
if envs.USE_VLLM_OPT_OP: if envs.VLLM_USE_OPT_OP:
ops.rms_norm_opt( ops.rms_norm_opt(
out, out,
x, x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment