"test/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "d9b31011aa836d7ddc3eabd88a888b36056c0334"
Unverified commit 12c39e56, authored by Jiewen Tan, committed by GitHub
Browse files

Fix use_cache for xla fsdp (#30353)

* Fix use_cache for xla fsdp

* Fix linters
parent b8b1e442
......@@ -1682,6 +1682,12 @@ class Trainer:
)
fsdp_kwargs = self.args.xla_fsdp_config
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
if model.config.use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
model.config.use_cache = False
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
def auto_wrapper_callable(m, *args, **kwargs):
target_cls = FSDP if not self.is_fsdp_xla_v2_enabled else FSDPv2
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment