"vscode:/vscode.git/clone" did not exist on "73a008d47909f0feb2e614682efc1af5e9b5fb3f"
Commit d2fdeac2 authored by maxiao1's avatar maxiao1
Browse files

调用vllm里custom all reduce

parent 75cd34d1
......@@ -22,9 +22,11 @@ use_vllm_custom_allreduce = get_bool_env_var(
if not is_hpu():
# ROCm does not use vllm custom allreduce
if use_vllm_custom_allreduce and not is_hip():
# if use_vllm_custom_allreduce and not is_hip():
if use_vllm_custom_allreduce:
try:
import vllm._C # noqa: F401
print("[DEBUG] ✅ Using vLLM custom allreduce (vllm._C successfully imported)")
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
else:
......@@ -34,9 +36,11 @@ if not is_hpu():
logger.warning("Failed to import from custom_ar with %r", e)
if not is_hip() and not is_npu():
# if not is_hip() and not is_npu():
if not is_npu():
if use_vllm_custom_allreduce:
custom_op = torch.ops._C_custom_ar
print("[DEBUG] ✅ custom_op = torch.ops._C_custom_ar (vLLM path active)")
else:
custom_op = sgl_kernel.allreduce
......
......@@ -27,7 +27,8 @@ _is_hip = is_hip()
try:
if ops.use_vllm_custom_allreduce and not _is_hip:
# if ops.use_vllm_custom_allreduce and not _is_hip:
if ops.use_vllm_custom_allreduce:
# Use vLLM custom allreduce
ops.meta_size()
else:
......
......@@ -1539,7 +1539,6 @@ def initialize_model_parallel(
group_name="tp",
pynccl_use_current_stream=duplicate_tp_group,
torch_compile=torch_compile,
use_custom_allreduce = False,
)
if duplicate_tp_group:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment