"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1ae182d9a62f178caba8d9a472f8a6aa67142702"
Unverified Commit f6fa0f0b authored by Reza Yazdani, committed by GitHub
Browse files

Create the arange tensor on device for enabling CUDA-Graph for Clip Encoder (#19503)



* create the arange tensor on device for enabling CUDA-Graph at higher performance for SD

* sync
Co-authored-by: default avatarStas Bekman <stas@stason.org>
parent 6cd8676c
@@ -662,7 +662,7 @@ class CLIPTextTransformer(nn.Module):
         # take features from the eot embedding (eot_token is the highest number in each sequence)
         # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
         pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
+            torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
         ]
         if not return_dict:
...
@@ -1134,7 +1134,7 @@ class GroupViTTextTransformer(nn.Module):
         # take features from the eot embedding (eot_token is the highest number in each sequence)
         # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
         pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
+            torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
         ]
         if not return_dict:
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.