Unverified Commit a530b3ff authored by Stefan He, committed by GitHub

[RL] fix registering the same ops multiple times (#9564)

parent 603b3446
...
@@ -146,27 +146,21 @@ def _quant_dequant_mxfp4_fake(
     return torch.empty_like(x)


-try:
-    direct_register_custom_op(
-        op_name="dequant_mxfp4",
-        op_func=_dequant_mxfp4,
-        mutates_args=[],
-        fake_impl=_dequant_mxfp4_fake,
-    )
-    dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
-except AttributeError as error:
-    raise error
+direct_register_custom_op(
+    op_name="dequant_mxfp4",
+    op_func=_dequant_mxfp4,
+    mutates_args=[],
+    fake_impl=_dequant_mxfp4_fake,
+)
+dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4

-try:
-    direct_register_custom_op(
-        op_name="quant_dequant_mxfp4",
-        op_func=_quant_dequant_mxfp4,
-        mutates_args=[],
-        fake_impl=_quant_dequant_mxfp4_fake,
-    )
-    quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
-except AttributeError as error:
-    raise error
+direct_register_custom_op(
+    op_name="quant_dequant_mxfp4",
+    op_func=_quant_dequant_mxfp4,
+    mutates_args=[],
+    fake_impl=_quant_dequant_mxfp4_fake,
+)
+quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4


 class Mxfp4Config(QuantizationConfig):
...
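For illustration, a minimal sketch of the calling pattern this hunk relies on, assuming sglang is installed and that direct_register_custom_op lives in sglang.srt.utils (the file patched below); the demo_double op and its helpers are hypothetical. Note that the removed try/except blocks never handled the duplicate-registration RuntimeError anyway: they caught only AttributeError and immediately re-raised it. With the check added below, importing such a module twice in one process, as multi-engine setups do, becomes safe.

# Hypothetical sketch, not part of the patch: assumes sglang is installed and
# direct_register_custom_op is importable from sglang.srt.utils.
import torch

from sglang.srt.utils import direct_register_custom_op


def _demo_double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


def _demo_double_fake(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


# Before this patch, the second iteration raised
# "RuntimeError: Tried to register an operator ... multiple times";
# after it, the duplicate call is silently skipped.
for _ in range(2):
    direct_register_custom_op(
        op_name="demo_double",
        op_func=_demo_double,
        mutates_args=[],
        fake_impl=_demo_double_fake,
    )

demo_double = torch.ops.sglang.demo_double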
...
@@ -1665,9 +1665,29 @@ def direct_register_custom_op(
     IMPORTANT: the lifetime of the operator is tied to the lifetime of the
     library object. If you want to bind the operator to a different library,
     make sure the library object is alive when the operator is used.
+
+    Note: This function silently skips registration if an operator with the
+    same name is already registered, to avoid a RuntimeError in multi-engine
+    scenarios (e.g., the VERL framework).
     """
     import torch.library

+    my_lib = target_lib or sglang_lib
+
+    # Check whether the operator is already registered to avoid duplicate
+    # registration when multiple SGLang engines run in the same process.
+    try:
+        # Probe the operator to see if it is already registered.
+        lib_name = my_lib.m.name if hasattr(my_lib.m, "name") else "sglang"
+        if hasattr(torch.ops, lib_name) and hasattr(
+            getattr(torch.ops, lib_name), op_name
+        ):
+            # The operator already exists; skip registration.
+            return
+    except (AttributeError, RuntimeError):
+        # The operator does not exist; proceed with registration.
+        pass
+
     if hasattr(torch.library, "infer_schema"):
         schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
     else:
...
@@ -1676,11 +1696,22 @@ def direct_register_custom_op(
         schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)

-    my_lib = target_lib or sglang_lib
-    my_lib.define(op_name + schema_str)
-    my_lib.impl(op_name, op_func, "CUDA")
-    if fake_impl is not None:
-        my_lib._register_fake(op_name, fake_impl)
+    try:
+        my_lib.define(op_name + schema_str)
+        my_lib.impl(op_name, op_func, "CUDA")
+        if fake_impl is not None:
+            my_lib._register_fake(op_name, fake_impl)
+    except RuntimeError as error:
+        if "Tried to register an operator" in str(error) and "multiple times" in str(
+            error
+        ):
+            # Silently ignore duplicate registration errors; this can happen
+            # in multi-engine scenarios.
+            pass
+        else:
+            # Re-raise other RuntimeErrors.
+            raise error
+    except AttributeError as error:
+        # Always re-raise AttributeError, as it indicates missing dependencies.
+        raise error


 def set_gpu_proc_affinity(
...
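The pre-registration check added above boils down to probing torch.ops.<namespace>.<op>. Below is a self-contained sketch of that guard using only torch.library, with a hypothetical demo_dup namespace and register_once helper; it does not touch sglang at all.

import torch
import torch.library

# Hypothetical namespace for demonstration; "FRAGMENT" lets us add new ops.
_demo_lib = torch.library.Library("demo_dup", "FRAGMENT")


def _identity(x: torch.Tensor) -> torch.Tensor:
    return x.clone()


def register_once(op_name: str) -> None:
    # Same idea as the patch: if torch.ops.demo_dup.<op_name> already
    # resolves, the op is registered, so skip the define/impl calls.
    # (_OpNamespace.__getattr__ raises AttributeError for unknown ops,
    # so hasattr returns False until the op exists.)
    if hasattr(torch.ops.demo_dup, op_name):
        return
    _demo_lib.define(f"{op_name}(Tensor x) -> Tensor")
    _demo_lib.impl(op_name, _identity, "CPU")


register_once("identity")
register_once("identity")  # no-op instead of a duplicate-registration RuntimeError
print(torch.ops.demo_dup.identity(torch.ones(2)))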
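As for the schema_str that gets concatenated with op_name: it is inferred from the prototype function's type annotations and starts with the parenthesized argument list, which is why op_name + schema_str forms a complete definition. A small sketch, assuming torch >= 2.5 where torch.library.infer_schema is public (older versions fall back to torch._custom_op.impl.infer_schema, as the hunk above shows):

import torch
import torch.library


def _scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


# Prints "(Tensor x, float factor) -> Tensor", so prepending an op name
# yields a full definition such as "scale(Tensor x, float factor) -> Tensor".
schema_str = torch.library.infer_schema(_scale, mutates_args=[])
print(schema_str)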