Unverified Commit 6b9a3334 authored by Juan Acevedo, committed by GitHub

reverts accidental change that removes attn_mask in attn. Improves fl… (#11065)



Reverts an accidental change that removed attn_mask in the attention processor. Improves Flux PyTorch/XLA performance by setting flash-attention block sizes. Moves prompt encoding outside the for loop.
Co-authored-by: Juan Acevedo <jfacevedo@google.com>
parent 8ead643b
@@ -9,6 +9,7 @@ import torch_xla.debug.metrics as met
 import torch_xla.debug.profiler as xp
 import torch_xla.distributed.xla_multiprocessing as xmp
 import torch_xla.runtime as xr
+from torch_xla.experimental.custom_kernel import FlashAttention
 from diffusers import FluxPipeline
@@ -36,6 +37,19 @@ def _main(index, args, text_pipe, ckpt_id):
         ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16
     ).to(device0)
     flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
+    FlashAttention.DEFAULT_BLOCK_SIZES = {
+        "block_q": 1536,
+        "block_k_major": 1536,
+        "block_k": 1536,
+        "block_b": 1536,
+        "block_q_major_dkv": 1536,
+        "block_k_major_dkv": 1536,
+        "block_q_dkv": 1536,
+        "block_k_dkv": 1536,
+        "block_q_dq": 1536,
+        "block_k_dq": 1536,
+        "block_k_major_dq": 1536,
+    }
     prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
     width = args.width
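Editor's note: the hunk above works because torch_xla's FlashAttention wrapper reads its Pallas tile sizes from a class-level dict at call time, so reassigning it once before the first attention call retunes every subsequent kernel launch. A minimal sketch of an equivalent override, assuming the installed torch_xla version exposes DEFAULT_BLOCK_SIZES with these keys:

```python
# Sketch only: globally override the Pallas flash-attention tile sizes.
# Assumes torch_xla exposes FlashAttention.DEFAULT_BLOCK_SIZES as a dict
# keyed by block name; 1536 is the value this commit picked, not a
# universal optimum.
from torch_xla.experimental.custom_kernel import FlashAttention

FlashAttention.DEFAULT_BLOCK_SIZES = {
    key: 1536 for key in FlashAttention.DEFAULT_BLOCK_SIZES
}
# Run this before the first flash-attention call: the sizes are read each
# time the kernel wrapper is invoked, so earlier launches keep the old tiling.
```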
@@ -69,14 +83,14 @@ def _main(index, args, text_pipe, ckpt_id):
     xm.set_rng_state(seed=unique_seed, device=device0)
     times = []
     logger.info("starting inference run...")
+    with torch.no_grad():
+        prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
+            prompt=prompt, prompt_2=None, max_sequence_length=512
+        )
+    prompt_embeds = prompt_embeds.to(device0)
+    pooled_prompt_embeds = pooled_prompt_embeds.to(device0)

     for _ in range(args.itters):
         ts = perf_counter()
-        with torch.no_grad():
-            prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
-                prompt=prompt, prompt_2=None, max_sequence_length=512
-            )
-        prompt_embeds = prompt_embeds.to(device0)
-        pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
         if args.profile:
             xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
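Editor's note: the hunk above hoists prompt encoding out of the timed loop; the prompt never changes across iterations, so encoding it once means each timed iteration measures only the denoising pass. A generic sketch of the pattern (`benchmark` and `run_denoise` are illustrative names, not part of this diff):

```python
from time import perf_counter

def benchmark(run_denoise, iters, *invariant_inputs):
    # invariant_inputs (e.g. precomputed prompt embeddings) are produced
    # once, outside this function, so only the variant work is timed.
    times = []
    for _ in range(iters):
        ts = perf_counter()
        run_denoise(*invariant_inputs)
        times.append(perf_counter() - ts)
    return sum(times) / len(times)
```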
@@ -92,7 +106,7 @@ def _main(index, args, text_pipe, ckpt_id):
         if index == 0:
             logger.info(f"inference time: {inference_time}")
         times.append(inference_time)
-    logger.info(f"avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.")
+    logger.info(f"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.")
     image.save(f"/tmp/inference_out-{index}.png")
     if index == 0:
         metrics_report = met.metrics_report()
...
@@ -2339,7 +2339,9 @@ class FluxAttnProcessor2_0:
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
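Editor's note: restoring attn_mask matters because F.scaled_dot_product_attention attends to every key when the mask is omitted, padding tokens included. A minimal self-contained sketch of the argument's semantics (shapes are arbitrary; a boolean True means "may attend"):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Boolean mask broadcast over heads; hide the last 4 keys as if they were
# text-encoder padding tokens.
mask = torch.ones(1, 1, 16, 16, dtype=torch.bool)
mask[..., 12:] = False

out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
)
assert out.shape == (1, 8, 16, 64)
```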
...