Unverified commit 2205cff2, authored by eqy, committed by GitHub
Browse files

check in (#1210)

parent fa8bd7e6
......@@ -49,10 +49,6 @@ def _forward_backward_pipelining_with_interleaving(
"""
if not isinstance(model, list):
raise RuntimeError("`model` must be a list of `nn.Module`'s'")
# TODO (mkozuki): Sanity check the following condition.
if len(batch) != len(model):
msg = f"`batch` and `model` must have the same number of elements. Actual {len(batch)} and {len(model)}"
raise RuntimeError(msg)
num_model_chunks = len(model)
input_tensors = [[] for _ in range(num_model_chunks)]
......@@ -122,7 +118,7 @@ def _forward_backward_pipelining_with_interleaving(
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = forward_step(
forward_step_func,
get_kth_microbatch(batch[model_chunk_id], curr_iters[model_chunk_id]),
get_kth_microbatch(batch, curr_iters[model_chunk_id]),
model[model_chunk_id],
input_tensor,
losses_reduced,
......
......@@ -125,10 +125,7 @@ def forward_backward_func_template(
torch.optim.Adam(_param_groups)
tensor_shape = [batch_size // parallel_state.get_data_parallel_world_size(), hidden_size]
if virtual_pipeline_model_parallel_size is None:
batch = (torch.randn(tensor_shape).cuda(),)
else:
batch = [(torch.randn(tensor_shape).cuda(),) for _ in range(virtual_pipeline_model_parallel_size)]
batch = (torch.randn(tensor_shape).cuda(),)
tensor_shape[0] = micro_batch_size
update_num_microbatches(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment