Commit 91d33c79 authored by Morgan Funtowicz
Browse files

Fix issue in pipelines where PyTorch tensors are not copied to the user-specified GPU device.


Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
parent c301faa9
...@@ -335,13 +335,13 @@ class Pipeline(_ScikitCompat): ...@@ -335,13 +335,13 @@ class Pipeline(_ScikitCompat):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.modelcard = modelcard self.modelcard = modelcard
self.framework = framework self.framework = framework
self.device = device self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
self.binary_output = binary_output self.binary_output = binary_output
self._args_parser = args_parser or DefaultArgumentHandler() self._args_parser = args_parser or DefaultArgumentHandler()
# Special handling # Special handling
if self.device >= 0 and self.framework == "pt": if self.framework == "pt" and self.device.type == "cuda":
self.model = self.model.to("cuda:{}".format(self.device)) self.model = self.model.to(self.device)
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory):
""" """
...@@ -385,11 +385,19 @@ class Pipeline(_ScikitCompat): ...@@ -385,11 +385,19 @@ class Pipeline(_ScikitCompat):
with tf.device("/CPU:0" if self.device == -1 else "/device:GPU:{}".format(self.device)): with tf.device("/CPU:0" if self.device == -1 else "/device:GPU:{}".format(self.device)):
yield yield
else: else:
if self.device >= 0: if self.device.type == "cuda":
torch.cuda.set_device(self.device) torch.cuda.set_device(self.device)
yield yield
def ensure_tensor_on_device(self, **inputs):
    """
    Ensure PyTorch tensors are on the specified device.

    :param inputs: keyword arguments mapping each input name to its value
        (typically a tensor that is about to be fed to the model).
    :return: a new dict with the same keys; every value exposing a ``to``
        method is replaced by ``value.to(self.device)``, any other value
        (e.g. ``None`` or a plain Python list) is passed through unchanged.
    """
    # Only objects that can actually be moved are moved; calling ``.to`` on
    # a non-tensor value would otherwise raise AttributeError.
    return {
        name: value.to(self.device) if hasattr(value, "to") else value
        for name, value in inputs.items()
    }
def inputs_for_model(self, features: Union[dict, List[dict]]) -> Dict: def inputs_for_model(self, features: Union[dict, List[dict]]) -> Dict:
""" """
Generates the input dictionary with model-specific parameters. Generates the input dictionary with model-specific parameters.
...@@ -415,16 +423,13 @@ class Pipeline(_ScikitCompat): ...@@ -415,16 +423,13 @@ class Pipeline(_ScikitCompat):
def __call__(self, *texts, **kwargs): def __call__(self, *texts, **kwargs):
# Parse arguments # Parse arguments
inputs = self._args_parser(*texts, **kwargs) inputs = self._args_parser(*texts, **kwargs)
inputs = self.tokenizer.batch_encode_plus(
inputs, add_special_tokens=True, return_tensors=self.framework, max_length=self.tokenizer.max_len
)
# Encode for forward # Filter out features not available on specific models
with self.device_placement(): inputs = self.inputs_for_model(inputs)
inputs = self.tokenizer.batch_encode_plus( return self._forward(inputs)
inputs, add_special_tokens=True, return_tensors=self.framework, max_length=self.tokenizer.max_len
)
# Filter out features not available on specific models
inputs = self.inputs_for_model(inputs)
return self._forward(inputs)
def _forward(self, inputs): def _forward(self, inputs):
""" """
...@@ -434,12 +439,15 @@ class Pipeline(_ScikitCompat): ...@@ -434,12 +439,15 @@ class Pipeline(_ScikitCompat):
Returns: Returns:
Numpy array Numpy array
""" """
if self.framework == "tf": # Encode for forward
# TODO trace model with self.device_placement():
predictions = self.model(inputs, training=False)[0] if self.framework == "tf":
else: # TODO trace model
with torch.no_grad(): predictions = self.model(inputs, training=False)[0]
predictions = self.model(**inputs)[0].cpu() else:
with torch.no_grad():
inputs = self.ensure_tensor_on_device(**inputs)
predictions = self.model(**inputs)[0].cpu()
return predictions.numpy() return predictions.numpy()
...@@ -534,6 +542,7 @@ class NerPipeline(Pipeline): ...@@ -534,6 +542,7 @@ class NerPipeline(Pipeline):
input_ids = tokens["input_ids"].numpy()[0] input_ids = tokens["input_ids"].numpy()[0]
else: else:
with torch.no_grad(): with torch.no_grad():
tokens = self.ensure_tensor_on_device(**tokens)
entities = self.model(**tokens)[0][0].cpu().numpy() entities = self.model(**tokens)[0][0].cpu().numpy()
input_ids = tokens["input_ids"].cpu().numpy()[0] input_ids = tokens["input_ids"].cpu().numpy()[0]
...@@ -710,7 +719,7 @@ class QuestionAnsweringPipeline(Pipeline): ...@@ -710,7 +719,7 @@ class QuestionAnsweringPipeline(Pipeline):
else: else:
with torch.no_grad(): with torch.no_grad():
# Retrieve the score for the context tokens only (removing question tokens) # Retrieve the score for the context tokens only (removing question tokens)
fw_args = {k: torch.tensor(v) for (k, v) in fw_args.items()} fw_args = {k: torch.tensor(v, device=self.device) for (k, v) in fw_args.items()}
start, end = self.model(**fw_args) start, end = self.model(**fw_args)
start, end = start.cpu().numpy(), end.cpu().numpy() start, end = start.cpu().numpy(), end.cpu().numpy()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment