Unverified Commit 9288e759 authored by Ashok Pon Kumar's avatar Ashok Pon Kumar Committed by GitHub
Browse files

fix: Avoid error when fsdp_config is missing xla_fsdp_v2 (#29480)


Signed-off-by: default avatarAshok Pon Kumar Sree Prakash <ashokponkumar@gmail.com>
parent f6133d76
......@@ -647,7 +647,7 @@ class Trainer:
if args.torch_compile and not is_torch_compile_available():
raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
self.is_fsdp_xla_v2_enabled = args.fsdp_config["xla_fsdp_v2"]
self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False)
if self.is_fsdp_xla_v2_enabled:
# Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
# Tensor axis is just a placeholder where it will not be used in FSDPv2.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment