Unverified Commit 28c145eb authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Bugfix] Fix typo in Pallas backend (#5558)

parent e2afb03c
...@@ -110,7 +110,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -110,7 +110,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise NotImplementedError("TPU version must be 4 or higher.") raise NotImplementedError("TPU version must be 4 or higher.")
self.megacore_mode = None self.megacore_mode = None
tpu_type = torch_xla.tpu.get_tp_groupu_env()["TYPE"].lower() tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
if not tpu_type.endswith("lite"): if not tpu_type.endswith("lite"):
if self.num_kv_heads % 2 == 0: if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head" self.megacore_mode = "kv_head"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment