Commit 5bcc463d authored by aiss's avatar aiss
Browse files

update v0.9.2

parent ac5fbab4
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from unit.common import DistributedTest
from unit.simple_model import *
from unit.checkpoint.common import checkpoint_correctness_verification
from unit.util import skip_on_arch
import pytest
......@@ -14,6 +17,8 @@ class TestPipelineCheckpoint(DistributedTest):
@pytest.mark.parametrize("zero_stage", [0, 1])
def test_checkpoint_pipe_engine(self, zero_stage, tmpdir):
skip_on_arch(min_arch=7)
config_dict = {
"train_batch_size": 2,
"train_micro_batch_size_per_gpu": 1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment