Unverified Commit 14ddb430 authored by LucienXian's avatar LucienXian Committed by GitHub
Browse files

Fix meta device check failure when passing torch.device objects (#2519)



* Fix meta device check failure when passing torch.device objects
Signed-off-by: default avatarLucienXian <fl.xian@foxmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarLucienXian <fl.xian@foxmail.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 442513c5
......@@ -702,7 +702,8 @@ class GroupedLinear(TransformerEngineBaseModule):
if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=self.num_gemms)
self.reset_parameters(defer_init=device == "meta")
is_meta = torch.device(device).type == "meta"
self.reset_parameters(defer_init=is_meta)
if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment