Commit f0cce574 authored by Gustaf Ahdritz

Fix imports

parent c269978a
@@ -18,7 +18,9 @@ import math
 from typing import Optional, Callable, List, Tuple, Sequence

 import numpy as np
-import deepspeed
+deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
+if(deepspeed_is_installed):
+    import deepspeed

 fa_is_installed = importlib.util.find_spec("flash_attn") is not None
 if(fa_is_installed):
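The change above swaps the unconditional import deepspeed for the probe-then-import pattern already used for flash_attn: check availability with importlib.util.find_spec and only import when the package is present. A minimal, self-contained sketch of that pattern (the helper name optional_import is illustrative and not part of this commit):

import importlib
import importlib.util

def optional_import(name):
    # Return the module if it is installed, otherwise None.
    if importlib.util.find_spec(name) is None:
        return None
    return importlib.import_module(name)

deepspeed = optional_import("deepspeed")  # None on installs without DeepSpeed
if deepspeed is not None:
    print(deepspeed.__version__)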
@@ -191,7 +193,11 @@ class LayerNorm(nn.Module):

     def forward(self, x):
         d = x.dtype
-        if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
+        deepspeed_is_initialized = (
+            deepspeed_is_installed and
+            deepspeed.utils.is_initialized()
+        )
+        if(d is torch.bfloat16 and not deepspeed_is_initialized):
             with torch.cuda.amp.autocast(enabled=False):
                 out = nn.functional.layer_norm(
                     x,
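The rewritten condition in LayerNorm.forward stays safe when DeepSpeed is missing because `and` short-circuits: deepspeed.utils.is_initialized() is only evaluated once deepspeed_is_installed is known to be True. A tiny runnable illustration of that behaviour, with the flag hard-coded to False:

# `deepspeed` is never touched when the flag is False, so no NameError is raised
# even though the module was never imported here.
deepspeed_is_installed = False
deepspeed_is_initialized = (
    deepspeed_is_installed and
    deepspeed.utils.is_initialized()  # skipped by short-circuit evaluation
)
assert deepspeed_is_initialized is False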
@@ -219,7 +225,11 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
         type bfloat16
     """
     d = t.dtype
-    if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
+    deepspeed_is_initialized = (
+        deepspeed_is_installed and
+        deepspeed.utils.is_initialized()
+    )
+    if(d is torch.bfloat16 and not deepspeed_is_initialized):
         with torch.cuda.amp.autocast(enabled=False):
             s = torch.nn.functional.softmax(t, dim=dim)
     else:
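softmax_no_cast gets the same guard. Disabling autocast in the bfloat16 branch keeps the softmax in bfloat16 rather than letting autocast promote it to float32, as the function name suggests. A rough standalone sketch of the control flow, with the DeepSpeed check replaced by a plain flag:

import torch

def softmax_no_cast_sketch(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Stand-in for `deepspeed_is_installed and deepspeed.utils.is_initialized()`.
    deepspeed_is_initialized = False
    if t.dtype is torch.bfloat16 and not deepspeed_is_initialized:
        # Run the softmax with autocast disabled so the bfloat16 input is not
        # upcast behind the scenes.
        with torch.cuda.amp.autocast(enabled=False):
            return torch.nn.functional.softmax(t, dim=dim)
    return torch.nn.functional.softmax(t, dim=dim)

print(softmax_no_cast_sketch(torch.randn(4, 4, dtype=torch.bfloat16)).dtype)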
@@ -652,7 +662,7 @@ def _flash_attn(q, k, v, kv_mask):
         raise ValueError(
             "_flash_attn requires that FlashAttention be installed"
         )

     batch_dims = q.shape[:-3]
     no_heads, n, c = q.shape[-3:]
     dtype = q.dtype
...
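For reference, the lines above split the leading batch dimensions from the trailing (no_heads, n, c) attention dimensions. An illustrative shape walk-through; the concrete sizes are made up:

import torch

q = torch.randn(2, 4, 8, 128, 16)       # made-up shape: [*batch_dims, no_heads, n, c]
batch_dims = q.shape[:-3]                # (2, 4)
no_heads, n, c = q.shape[-3:]            # 8, 128, 16
q_flat = q.reshape(-1, no_heads, n, c)   # collapse the batch dims into one
assert q_flat.shape == (8, 8, 128, 16)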
@@ -15,16 +15,11 @@ import os
 import operator
 import time

+import dllogger as logger
+from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
 import numpy as np
+import torch.cuda.profiler as profiler
 from pytorch_lightning import Callback
-import torch.cuda.profiler as profiler
-
-# We make this optional for the Colab's sake
-try:
-    import dllogger as logger
-    from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
-except:
-    pass


 def is_main_process():
...
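For reference, the block removed from this file was the try/except form of an optional import. One caveat of that form as written is that a bare except: pass leaves logger and the backend names undefined, so any later use fails with a NameError rather than a clear error message. A sketch of the variant with an explicit fallback, not part of this commit:

# Optional import with an explicit fallback, so callers can check
# `if logger is None:` instead of hitting NameError later on.
try:
    import dllogger as logger
    from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
except ImportError:
    logger = None
    JSONStreamBackend = StdOutBackend = Verbosity = None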