Unverified commit 32ef4983, authored by Woosuk Kwon and committed by GitHub
Browse files

[V1] Temporarily disable FlashInfer Rejection Sampler (#14788)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent ad19c8a0
...@@ -22,7 +22,7 @@ class TopKTopPSampler(nn.Module): ...@@ -22,7 +22,7 @@ class TopKTopPSampler(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
if current_platform.is_cuda: if current_platform.is_cuda():
if is_flashinfer_available: if is_flashinfer_available:
if envs.VLLM_USE_FLASHINFER_SAMPLER is not False: if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
......
...@@ -24,9 +24,18 @@ class RejectionSampler(nn.Module): ...@@ -24,9 +24,18 @@ class RejectionSampler(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
if current_platform.is_cuda: if current_platform.is_cuda():
if is_flashinfer_available: if is_flashinfer_available:
if envs.VLLM_USE_FLASHINFER_SAMPLER is not False: if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
# FIXME(woosuk): Currently, we have errors when using
# FlashInfer for rejection sampling. As a workaround, we
# disable FlashInfer for rejection sampling by default.
logger.info("Currently, FlashInfer rejection sampler is "
"disabled because of a bug. Falling back to "
"the PyTorch-native implementation of "
"rejection sampling.")
self.forward_method = self.forward_native
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
# sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
# default it is unused). For backward compatibility, we set # default it is unused). For backward compatibility, we set
...@@ -35,8 +44,8 @@ class RejectionSampler(nn.Module): ...@@ -35,8 +44,8 @@ class RejectionSampler(nn.Module):
# None means False, while in V1, None means True. This is # None means False, while in V1, None means True. This is
# why we use the condition # why we use the condition
# `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
logger.info("Using FlashInfer for rejection sampling.") # logger.info("Using FlashInfer for rejection sampling.")
self.forward_method = self.flashinfer_sample # self.forward_method = self.flashinfer_sample
else: else:
logger.warning( logger.warning(
"FlashInfer is available, but it is not enabled. " "FlashInfer is available, but it is not enabled. "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment