Commit 9b861808 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Make DeepSpeed optional

parent 15b75e49
......@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import deepspeed
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
if(deepspeed_is_installed):
import deepspeed
import torch
import torch.utils.checkpoint
from typing import Any, Tuple, List, Callable, Optional
......@@ -23,7 +26,11 @@ BLOCK_ARGS = List[BLOCK_ARG]
def get_checkpoint_fn():
if(deepspeed.checkpointing.is_configured()):
deepspeed_is_configured = (
deepspeed_is_installed and
deepspeed.checkpointing.is_configured()
)
if(deepspeed_is_configured):
checkpoint = deepspeed.checkpointing.checkpoint
else:
checkpoint = torch.utils.checkpoint.checkpoint
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment