Commit 9b861808 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Make DeepSpeed optional

parent 15b75e49
...@@ -12,7 +12,10 @@ ...@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import deepspeed deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
if(deepspeed_is_installed):
import deepspeed
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from typing import Any, Tuple, List, Callable, Optional from typing import Any, Tuple, List, Callable, Optional
...@@ -23,7 +26,11 @@ BLOCK_ARGS = List[BLOCK_ARG] ...@@ -23,7 +26,11 @@ BLOCK_ARGS = List[BLOCK_ARG]
def get_checkpoint_fn(): def get_checkpoint_fn():
if(deepspeed.checkpointing.is_configured()): deepspeed_is_configured = (
deepspeed_is_installed and
deepspeed.checkpointing.is_configured()
)
if(deepspeed_is_configured):
checkpoint = deepspeed.checkpointing.checkpoint checkpoint = deepspeed.checkpointing.checkpoint
else: else:
checkpoint = torch.utils.checkpoint.checkpoint checkpoint = torch.utils.checkpoint.checkpoint
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment