"tests/vscode:/vscode.git/clone" did not exist on "8195824206ad2e3c45d1807b321c11f06ccb3a91"
Unverified Commit 5fcb0cdd authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Use FP32 for Gumbel Noise (#34854)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent c878b43b
...@@ -85,10 +85,10 @@ def _gumbel_sample_kernel( ...@@ -85,10 +85,10 @@ def _gumbel_sample_kernel(
pos = tl.load(pos_ptr + batch_idx) pos = tl.load(pos_ptr + batch_idx)
gumbel_seed = tl.randint(seed, pos) gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise. # Generate gumbel noise in FP32.
r = tl.rand(gumbel_seed, block).to(tl.float64) u = tl.rand(gumbel_seed, block)
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20) u = tl.maximum(u, 1e-7)
gumbel_noise = gumbel_noise.to(tl.float32) gumbel_noise = -tl.log(-tl.log(u))
# Apply temperature. # Apply temperature.
if APPLY_TEMPERATURE: if APPLY_TEMPERATURE:
...@@ -99,18 +99,17 @@ def _gumbel_sample_kernel( ...@@ -99,18 +99,17 @@ def _gumbel_sample_kernel(
# Apply gumbel noise. # Apply gumbel noise.
logits = tl.where(mask, logits + gumbel_noise, float("-inf")) logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
idx = tl.argmax(logits, axis=0) value, idx = tl.max(logits, axis=0, return_indices=True)
token_id = block_idx * BLOCK_SIZE + idx token_id = block_idx * BLOCK_SIZE + idx
value = tl.max(logits, axis=0)
tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id) tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value) tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
def gumbel_sample( def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size] logits: torch.Tensor, # [num_reqs, vocab_size]
idx_mapping: torch.Tensor, # [num_reqs] idx_mapping: torch.Tensor, # [max_num_reqs]
temperature: torch.Tensor, # [num_reqs] temperature: torch.Tensor, # [max_num_reqs]
seed: torch.Tensor, # [num_reqs] seed: torch.Tensor, # [max_num_reqs]
pos: torch.Tensor, # [num_reqs] pos: torch.Tensor, # [num_reqs]
apply_temperature: bool, apply_temperature: bool,
) -> torch.Tensor: ) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment