Unverified Commit a530b3ff authored by Stefan He, committed by GitHub

[RL] fix register the same ops multiple times (#9564)

parent 603b3446
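For context, the failure this commit addresses: when several SGLang engines start in the same process (e.g., under the VERL RL framework), each engine runs the module-level registration code, and PyTorch rejects the second attempt with a RuntimeError. A minimal, self-contained sketch of that failure mode — the demo_ops namespace and identity op are illustrative stand-ins, not part of this patch:

import torch
import torch.library

# Hypothetical throwaway namespace; SGLang's real library is "sglang".
lib = torch.library.Library("demo_ops", "FRAGMENT")

def _identity(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

# First registration succeeds.
lib.define("identity(Tensor x) -> Tensor")
lib.impl("identity", _identity, "CPU")

# A second engine in the same process re-runs the same registration code.
try:
    lib.define("identity(Tensor x) -> Tensor")
except RuntimeError as err:
    # PyTorch reports something like:
    # "Tried to register an operator (demo_ops::identity) ... multiple times"
    print(err)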
@@ -146,27 +146,21 @@ def _quant_dequant_mxfp4_fake(
     return torch.empty_like(x)
 
-try:
-    direct_register_custom_op(
-        op_name="dequant_mxfp4",
-        op_func=_dequant_mxfp4,
-        mutates_args=[],
-        fake_impl=_dequant_mxfp4_fake,
-    )
-    dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
-except AttributeError as error:
-    raise error
-
-try:
-    direct_register_custom_op(
-        op_name="quant_dequant_mxfp4",
-        op_func=_quant_dequant_mxfp4,
-        mutates_args=[],
-        fake_impl=_quant_dequant_mxfp4_fake,
-    )
-    quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
-except AttributeError as error:
-    raise error
+direct_register_custom_op(
+    op_name="dequant_mxfp4",
+    op_func=_dequant_mxfp4,
+    mutates_args=[],
+    fake_impl=_dequant_mxfp4_fake,
+)
+dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
+
+direct_register_custom_op(
+    op_name="quant_dequant_mxfp4",
+    op_func=_quant_dequant_mxfp4,
+    mutates_args=[],
+    fake_impl=_quant_dequant_mxfp4_fake,
+)
+quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
 
 
 class Mxfp4Config(QuantizationConfig):
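The call-site cleanup above is behavior-preserving: catching AttributeError only to re-raise it unchanged is equivalent to no try/except at all, so the guards were dead weight once error handling moved into direct_register_custom_op. An illustrative sketch (flaky_register is a stand-in, not SGLang code):

def flaky_register():
    raise AttributeError("simulated missing torch op")

def old_style():
    # Removed pattern: catch AttributeError, then re-raise it unchanged.
    try:
        flaky_register()
    except AttributeError as error:
        raise error

def new_style():
    # Simplified pattern after this commit: let the exception propagate.
    flaky_register()

for fn in (old_style, new_style):
    try:
        fn()
    except AttributeError as err:
        print(fn.__name__, "->", err)  # both report the identical error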
@@ -1665,9 +1665,29 @@ def direct_register_custom_op(
     IMPORTANT: the lifetime of the operator is tied to the lifetime of the
     library object. If you want to bind the operator to a different library,
     make sure the library object is alive when the operator is used.
+
+    Note: This function will silently skip registration if an operator with
+    the same name is already registered, to avoid a RuntimeError in
+    multi-engine scenarios (e.g., the VERL framework).
     """
     import torch.library
 
+    my_lib = target_lib or sglang_lib
+
+    # Check whether the operator is already registered to avoid duplicate
+    # registration. This matters when multiple SGLang engines run in the
+    # same process.
+    try:
+        # Probe the operator to see whether it is already registered.
+        lib_name = my_lib.m.name if hasattr(my_lib.m, "name") else "sglang"
+        if hasattr(torch.ops, lib_name) and hasattr(
+            getattr(torch.ops, lib_name), op_name
+        ):
+            # The operator already exists; skip registration.
+            return
+    except (AttributeError, RuntimeError):
+        # The operator does not exist; proceed with registration.
+        pass
+
     if hasattr(torch.library, "infer_schema"):
         schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
     else:
@@ -1676,11 +1696,22 @@ def direct_register_custom_op(
         schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
 
-    my_lib = target_lib or sglang_lib
-    my_lib.define(op_name + schema_str)
-    my_lib.impl(op_name, op_func, "CUDA")
-    if fake_impl is not None:
-        my_lib._register_fake(op_name, fake_impl)
+    try:
+        my_lib.define(op_name + schema_str)
+        my_lib.impl(op_name, op_func, "CUDA")
+        if fake_impl is not None:
+            my_lib._register_fake(op_name, fake_impl)
+    except RuntimeError as error:
+        if "Tried to register an operator" in str(error) and "multiple times" in str(
+            error
+        ):
+            # Silently ignore duplicate registration errors;
+            # this can happen in multi-engine scenarios.
+            pass
+        else:
+            # Re-raise any other RuntimeError.
+            raise error
+    except AttributeError as error:
+        # Always re-raise AttributeError, as it indicates missing dependencies.
+        raise error
 
 
 def set_gpu_proc_affinity(
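Taken together, the new logic makes registration effectively idempotent through two layers: a fast existence probe on torch.ops, plus a fallback that swallows only PyTorch's duplicate-registration RuntimeError. A standalone sketch of that pattern under assumed names (the demo_ops2 namespace and add_one op are hypothetical, and the probe here is simplified relative to the my_lib.m.name lookup in the diff):

import torch
import torch.library

lib = torch.library.Library("demo_ops2", "FRAGMENT")  # hypothetical namespace

def _add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

def register_once(op_name: str) -> None:
    # Layer 1: skip if the op is already visible under torch.ops.
    if hasattr(torch.ops.demo_ops2, op_name):
        return
    # Layer 2: tolerate a duplicate registration that slipped past the probe.
    try:
        lib.define(f"{op_name}(Tensor x) -> Tensor")
        lib.impl(op_name, _add_one, "CPU")
    except RuntimeError as error:
        if "Tried to register an operator" not in str(error):
            raise

register_once("add_one")
register_once("add_one")  # second call is a no-op
print(torch.ops.demo_ops2.add_one(torch.zeros(2)))  # tensor([1., 1.])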