Unverified Commit e2c39426 authored by Mehdi Mirzazadeh's avatar Mehdi Mirzazadeh Committed by GitHub
Browse files

Fixing memory leak in distributed pipeline (#724)

* Fixing memory leak in distributed pipeline

* fix mypy error
parent 681606f0
......@@ -273,14 +273,19 @@ class PartitionHandler:
pipeline_record = pipeline_record_rref.local_value()
self.run(pipeline_record)
result: Optional[Tensor] = None
if not pipeline_record.consumers:
result = microbatch.gather(pipeline_record.batches)
assert len(result) == 1
result = result[0]
gather_result = microbatch.gather(pipeline_record.batches)
assert len(gather_result) == 1
result = gather_result[0]
s0 = current_stream(result.device)
if is_cuda(s0):
# TODO. Investigate why this is needed and remove it if possible.
as_cuda(s0).synchronize()
return result
return None
# TODO: There seems to be a memory leak that is solved by the following line.
# Investigate why it is needed.
del pipeline_record.batches
return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment