Unverified Commit 34a20c49 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Logs] Change flashinfer sampler logs to once (#21759)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 31084b3b
...@@ -33,7 +33,7 @@ class TopKTopPSampler(nn.Module): ...@@ -33,7 +33,7 @@ class TopKTopPSampler(nn.Module):
if is_flashinfer_available: if is_flashinfer_available:
flashinfer_version = flashinfer.__version__ flashinfer_version = flashinfer.__version__
if flashinfer_version < "0.2.3": if flashinfer_version < "0.2.3":
logger.warning( logger.warning_once(
"FlashInfer version >= 0.2.3 required. " "FlashInfer version >= 0.2.3 required. "
"Falling back to default sampling implementation.") "Falling back to default sampling implementation.")
self.forward = self.forward_native self.forward = self.forward_native
...@@ -46,17 +46,18 @@ class TopKTopPSampler(nn.Module): ...@@ -46,17 +46,18 @@ class TopKTopPSampler(nn.Module):
# None means False, while in V1, None means True. This is # None means False, while in V1, None means True. This is
# why we use the condition # why we use the condition
# `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
logger.info("Using FlashInfer for top-p & top-k sampling.") logger.info_once(
"Using FlashInfer for top-p & top-k sampling.")
self.forward = self.forward_cuda self.forward = self.forward_cuda
else: else:
logger.warning( logger.warning_once(
"FlashInfer is available, but it is not enabled. " "FlashInfer is available, but it is not enabled. "
"Falling back to the PyTorch-native implementation of " "Falling back to the PyTorch-native implementation of "
"top-p & top-k sampling. For the best performance, " "top-p & top-k sampling. For the best performance, "
"please set VLLM_USE_FLASHINFER_SAMPLER=1.") "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
self.forward = self.forward_native self.forward = self.forward_native
else: else:
logger.warning( logger.warning_once(
"FlashInfer is not available. Falling back to the PyTorch-" "FlashInfer is not available. Falling back to the PyTorch-"
"native implementation of top-p & top-k sampling. For the " "native implementation of top-p & top-k sampling. For the "
"best performance, please install FlashInfer.") "best performance, please install FlashInfer.")
...@@ -97,9 +98,9 @@ class TopKTopPSampler(nn.Module): ...@@ -97,9 +98,9 @@ class TopKTopPSampler(nn.Module):
probs = logits.softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators) return random_sample(probs, generators)
if generators: if generators:
logger.warning("FlashInfer 0.2.3+ does not support " logger.warning_once("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to " "per-request generators. Falling back to "
"PyTorch-native implementation.") "PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p) return self.forward_native(logits, generators, k, p)
# flashinfer sampling functions expect contiguous logits. # flashinfer sampling functions expect contiguous logits.
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment