Commit f0cce574 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix imports

parent c269978a
......@@ -18,7 +18,9 @@ import math
from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np
import deepspeed
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
if(deepspeed_is_installed):
import deepspeed
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(fa_is_installed):
......@@ -191,7 +193,11 @@ class LayerNorm(nn.Module):
def forward(self, x):
d = x.dtype
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
)
if(d is torch.bfloat16 and not deepspeed_is_initialized):
with torch.cuda.amp.autocast(enabled=False):
out = nn.functional.layer_norm(
x,
......@@ -219,7 +225,11 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
type bfloat16
"""
d = t.dtype
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
)
if(d is torch.bfloat16 and not deepspeed_is_initialized):
with torch.cuda.amp.autocast(enabled=False):
s = torch.nn.functional.softmax(t, dim=dim)
else:
......@@ -652,7 +662,7 @@ def _flash_attn(q, k, v, kv_mask):
raise ValueError(
"_flash_attn requires that FlashAttention be installed"
)
batch_dims = q.shape[:-3]
no_heads, n, c = q.shape[-3:]
dtype = q.dtype
......
......@@ -15,16 +15,11 @@ import os
import operator
import time
import dllogger as logger
from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
import numpy as np
import torch.cuda.profiler as profiler
from pytorch_lightning import Callback
# We make this optional for the Colab's sake
try:
import dllogger as logger
from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
except:
pass
import torch.cuda.profiler as profiler
def is_main_process():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment