Commit 16de530e authored by yuguo's avatar yuguo
Browse files

[DCU] fix fsdp2

parent 11b6b7e4
......@@ -8,8 +8,6 @@ import os
import sys
import argparse
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
import torch
import torch.distributed as dist
......@@ -20,6 +18,8 @@ from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from contextlib import nullcontext
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
class SimpleNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
......
......@@ -6,10 +6,10 @@ import os
import pytest
import subprocess
from pathlib import Path
import torch
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import torch
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment