Commit a2aa804c authored by Casper Hansen's avatar Casper Hansen
Browse files

Catch out of memory exception

parent 0091f1e2
...@@ -55,7 +55,10 @@ def run_round(model_path, quant_file, n_generate, input_ids, batch_size): ...@@ -55,7 +55,10 @@ def run_round(model_path, quant_file, n_generate, input_ids, batch_size):
context_time, generate_time = generate(model, input_ids, n_generate) context_time, generate_time = generate(model, input_ids, n_generate)
successful_generate = True successful_generate = True
except RuntimeError as ex: except RuntimeError as ex:
if 'cuda out of memory' in str(ex).lower():
successful_generate = False successful_generate = False
else:
raise RuntimeError(ex)
device = next(model.parameters()).device device = next(model.parameters()).device
memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3) memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment