Commit a6ec1782 authored by Tri Dao

Bump to v0.2.6

parent 63670fd8
@@ -436,7 +436,7 @@ class MHA(nn.Module):
                 kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
                 # If we're processing the prompt, causal=None (use self.causal).
                 # If we're decoding, then causal=False.
-                causal = False if inference_params.sequence_len_offset == 0 else None
+                causal = None if inference_params.sequence_len_offset == 0 else False
                 context = self.inner_cross_attn(q, kv, causal=causal)
         else:
             if not self.return_residual:
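
A minimal sketch (not the library's code) of what this one-line fix changes: when sequence_len_offset is zero the module is processing the prompt, so causal=None lets the attention fall back to self.causal; once decoding starts, the single new query must attend to every cached key, so the causal mask is disabled. The pick_causal and toy_attention helpers below are illustrative names under those assumptions, not part of flash_attn.

import torch

def pick_causal(sequence_len_offset):
    # Mirrors the corrected line above: prompt pass -> None (defer to self.causal),
    # single-token decoding -> False (the new query attends to the whole KV cache).
    return None if sequence_len_offset == 0 else False

def toy_attention(q, k, v, causal):
    # Plain softmax attention with an optional causal mask, for illustration only.
    scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5
    if causal:
        mask = torch.triu(torch.ones(scores.shape[-2:], dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

self_causal = True  # stand-in for self.causal on a decoder-only model
d = 8

# Prompt pass: offset == 0 -> causal=None -> fall back to self.causal (masked).
q = k = v = torch.randn(4, d)
causal = pick_causal(0)
out_prompt = toy_attention(q, k, v, self_causal if causal is None else causal)

# Decoding pass: offset > 0 -> causal=False; one query against the full cache.
q_new, k_cache, v_cache = torch.randn(1, d), torch.randn(5, d), torch.randn(5, d)
causal = pick_causal(5)
out_step = toy_attention(q_new, k_cache, v_cache, self_causal if causal is None else causal)
print(out_prompt.shape, out_step.shape)  # torch.Size([4, 8]) torch.Size([1, 8])
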
@@ -40,7 +40,7 @@ def greedy_decode(input_ids, model, max_length):
     inference_params.sequence_len_offset = seqlen_og
     while True:
         position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                   dtype=torch.long, device=input_ids.device)
         logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                        inference_params=inference_params).logits[:, -1]
         scores.append(logits)
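
For context on the hunk above, here is a self-contained sketch of the loop shape it belongs to: score the prompt once, then feed one token per step together with its absolute position id (the growing sequence length offset). The step_fn callable and toy_greedy_decode name are assumptions made for this sketch; the real greedy_decode drives a model whose KV cache lives in inference_params.

import torch
from einops import rearrange

@torch.no_grad()
def toy_greedy_decode(input_ids, step_fn, max_length):
    # input_ids: (batch, seqlen) prompt tokens.
    # step_fn(tokens, position_ids) -> logits for the last position, shape (batch, vocab).
    batch_size, seqlen_og = input_ids.shape
    positions = torch.arange(seqlen_og, device=input_ids.device).expand(batch_size, -1)
    logits = step_fn(input_ids, positions)            # prompt pass
    next_token = logits.argmax(dim=-1)
    sequence_len_offset = seqlen_og
    scores, tokens = [logits], [next_token]
    while sequence_len_offset < max_length - 1:       # decoding passes, one token each
        position_ids = torch.full((batch_size, 1), sequence_len_offset,
                                  dtype=torch.long, device=input_ids.device)
        logits = step_fn(rearrange(next_token, 'b -> b 1'), position_ids)
        next_token = logits.argmax(dim=-1)
        scores.append(logits)
        tokens.append(next_token)
        sequence_len_offset += 1
    return torch.stack(tokens, dim=1), scores

vocab_size = 32
def dummy_step(tokens, position_ids):
    # Stand-in for model(...).logits[:, -1]; a real model would also read its KV cache.
    return torch.randn(tokens.shape[0], vocab_size)

generated, _ = toy_greedy_decode(torch.randint(0, vocab_size, (2, 4)), dummy_step, max_length=8)
print(generated.shape)  # torch.Size([2, 4]): four greedily chosen tokens per sequence
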
@@ -156,7 +156,7 @@ ext_modules.append(
 setup(
     name="flash_attn",
-    version="0.2.5",
+    version="0.2.6-1",
     packages=find_packages(
         exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
     ),
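
One way to confirm which build is installed after a version bump like this (a generic check, not something this commit adds) is to read the distribution metadata rather than relying on a hard-coded string:

# Python 3.8+: query the installed distribution's version from its metadata.
from importlib.metadata import version, PackageNotFoundError

try:
    print(version("flash_attn"))  # prints the installed flash_attn version string
except PackageNotFoundError:
    print("flash_attn is not installed in this environment")
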