"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3d2648d743e4257c550bba03242486b1f3834838"
Commit d2fdeac2 authored by maxiao1

Call vLLM's custom all reduce

parent 75cd34d1
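
In short: this commit drops the ROCm/HIP exclusions around vLLM's custom allreduce across three files. The `not is_hip()` guards are commented out so HIP (DCU) builds also import `vllm._C` and bind `custom_op` to `torch.ops._C_custom_ar`, debug prints mark where the vLLM path is active, and the explicit `use_custom_allreduce = False` override is removed from `initialize_model_parallel`.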
@@ -22,9 +22,11 @@ use_vllm_custom_allreduce = get_bool_env_var(
 if not is_hpu():
     # ROCm does not use vllm custom allreduce
-    if use_vllm_custom_allreduce and not is_hip():
+    # if use_vllm_custom_allreduce and not is_hip():
+    if use_vllm_custom_allreduce:
         try:
             import vllm._C  # noqa: F401
+            print("[DEBUG] ✅ Using vLLM custom allreduce (vllm._C successfully imported)")
         except ImportError as e:
             logger.warning("Failed to import from vllm._C with %r", e)
     else:
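
For readers outside the repo, a minimal standalone sketch of the env-gated import pattern this hunk changes; the env-var name and default below are assumptions for illustration, not taken from the codebase:

import logging
import os

logger = logging.getLogger(__name__)


def get_bool_env_var(name: str, default: str = "false") -> bool:
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")


# Assumed variable name, for illustration only.
use_vllm_custom_allreduce = get_bool_env_var("USE_VLLM_CUSTOM_ALLREDUCE", "true")

# After this commit, HIP/DCU builds take the same path: no is_hip() check.
if use_vllm_custom_allreduce:
    try:
        import vllm._C  # noqa: F401  # side effect: registers torch.ops._C_custom_ar
        print("[DEBUG] ✅ Using vLLM custom allreduce (vllm._C successfully imported)")
    except ImportError as e:
        logger.warning("Failed to import from vllm._C with %r", e)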
@@ -34,9 +36,11 @@ if not is_hpu():
             logger.warning("Failed to import from custom_ar with %r", e)

-if not is_hip() and not is_npu():
+# if not is_hip() and not is_npu():
+if not is_npu():
     if use_vllm_custom_allreduce:
         custom_op = torch.ops._C_custom_ar
+        print("[DEBUG] ✅ custom_op = torch.ops._C_custom_ar (vLLM path active)")
     else:
         custom_op = sgl_kernel.allreduce
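
The net effect of the two hunks above is a single backend handle chosen at import time. A minimal sketch of that dispatch, continuing the `use_vllm_custom_allreduce` flag from the sketch above (the `meta_size` call at the end is illustrative; it is the only API this commit itself exercises):

import torch

# Pick one allreduce op table; callers stay backend-agnostic afterwards.
# torch.ops._C_custom_ar and sgl_kernel.allreduce are the backends named
# in the diff; everything else here is illustrative.
if use_vllm_custom_allreduce:
    custom_op = torch.ops._C_custom_ar  # registered by importing vllm._C
else:
    import sgl_kernel

    custom_op = sgl_kernel.allreduce  # sgl-kernel's own custom-allreduce module

meta_bytes = custom_op.meta_size()  # same call regardless of backend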
@@ -27,7 +27,8 @@ _is_hip = is_hip()
 try:
-    if ops.use_vllm_custom_allreduce and not _is_hip:
+    # if ops.use_vllm_custom_allreduce and not _is_hip:
+    if ops.use_vllm_custom_allreduce:
         # Use vLLM custom allreduce
         ops.meta_size()
     else:
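
In this second file, wrapping `ops.meta_size()` in `try:` doubles as an availability probe: if the selected extension never actually loaded, the trivial call raises and the module can fall back. A sketch of that idea; the import path is an assumption, not shown in this diff:

try:
    from sglang.srt import _custom_ops as ops  # assumed import path

    ops.meta_size()  # raises if the chosen custom-ar extension never loaded
    custom_ar_supported = True
except Exception:
    custom_ar_supported = False  # callers fall back to a plain NCCL allreduce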
@@ -1539,7 +1539,6 @@ def initialize_model_parallel(
         group_name="tp",
         pynccl_use_current_stream=duplicate_tp_group,
         torch_compile=torch_compile,
-        use_custom_allreduce = False,
     )
     if duplicate_tp_group:
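
Removing the hard-coded `use_custom_allreduce = False` means the TP group creation now falls back to the callee's default for that parameter. The default value is not shown in this diff, but in the context of this commit it presumably leaves the custom-allreduce path enabled.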