Unverified Commit cb47293e authored by Shanmugam Ramasamy, committed by GitHub

Patching clip model to create mask tensor on the device (#22711)



* Patching clip model to create mask tensor on the device

* Addressing PR's comments

* Addressing PR's comments

* Addressing PR's comments

---------
Co-authored-by: Shanmugam Ramasamy <shanmugamr@shanmugamr-mlt.client.nvidia.com>
parent 2da73f63
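
At a glance, the change is this: the causal attention mask used to be allocated on the CPU and then copied onto the device of the hidden states with .to(hidden_states.device); after the patch it is allocated directly on that device, so every forward pass skips a host-to-device transfer (the redundant torch.tensor wrapper around the fill value is also dropped). Below is a minimal standalone sketch of the pattern in PyTorch; the function names are illustrative only and not part of the models being patched.

import torch

# Old pattern: build the mask on the CPU, then copy it to the target device.
def build_mask_then_move(bsz, seq_len, dtype, device):
    mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
    mask.fill_(torch.finfo(dtype).min)  # additive mask: large negative blocks attention
    mask.triu_(1)                       # zero out the diagonal and below
    return mask.to(device)              # extra host-to-device copy on every call

# New pattern: build the mask directly on the target device.
def build_mask_on_device(bsz, seq_len, dtype, device=None):
    mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
    mask.fill_(torch.finfo(dtype).min)
    mask.triu_(1)
    return mask
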
@@ -714,8 +714,8 @@ class CLIPTextTransformer(nn.Module):
         bsz, seq_len = input_shape
         # CLIP's text model uses causal mask, prepare it here.
         # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
-        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
-            hidden_states.device
+        causal_attention_mask = self._build_causal_attention_mask(
+            bsz, seq_len, hidden_states.dtype, device=hidden_states.device
         )
         # expand attention_mask
         if attention_mask is not None:
@@ -752,11 +752,11 @@ class CLIPTextTransformer(nn.Module):
             attentions=encoder_outputs.attentions,
         )
 
-    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
+    def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
         # lazily create causal attention mask, with full attention between the vision tokens
         # pytorch uses additive attention mask; fill with -inf
-        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
-        mask.fill_(torch.tensor(torch.finfo(dtype).min))
+        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
+        mask.fill_(torch.finfo(dtype).min)
         mask.triu_(1)  # zero out the lower diagonal
         mask = mask.unsqueeze(1)  # expand mask
         return mask
@@ -726,8 +726,8 @@ class CLIPSegTextTransformer(nn.Module):
         bsz, seq_len = input_shape
         # CLIPSeg's text model uses causal mask, prepare it here.
         # https://github.com/openai/CLIPSeg/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clipseg/model.py#L324
-        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
-            hidden_states.device
+        causal_attention_mask = self._build_causal_attention_mask(
+            bsz, seq_len, hidden_states.dtype, device=hidden_states.device
         )
         # expand attention_mask
         if attention_mask is not None:
@@ -764,11 +764,11 @@ class CLIPSegTextTransformer(nn.Module):
             attentions=encoder_outputs.attentions,
         )
 
-    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
+    def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
         # lazily create causal attention mask, with full attention between the vision tokens
         # pytorch uses additive attention mask; fill with -inf
-        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
-        mask.fill_(torch.tensor(torch.finfo(dtype).min))
+        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
+        mask.fill_(torch.finfo(dtype).min)
         mask.triu_(1)  # zero out the lower diagonal
         mask = mask.unsqueeze(1)  # expand mask
         return mask
@@ -1108,8 +1108,8 @@ class GroupViTTextTransformer(nn.Module):
         bsz, seq_len = input_shape
         # CLIP's text model uses causal mask, prepare it here.
         # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
-        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
-            hidden_states.device
+        causal_attention_mask = self._build_causal_attention_mask(
+            bsz, seq_len, hidden_states.dtype, device=hidden_states.device
         )
         # expand attention_mask
         if attention_mask is not None:
@@ -1146,11 +1146,11 @@ class GroupViTTextTransformer(nn.Module):
             attentions=encoder_outputs.attentions,
         )
 
-    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
+    def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
        # lazily create causal attention mask, with full attention between the vision tokens
         # pytorch uses additive attention mask; fill with -inf
-        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
-        mask.fill_(torch.tensor(torch.finfo(dtype).min))
+        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
+        mask.fill_(torch.finfo(dtype).min)
         mask.triu_(1)  # zero out the lower diagonal
         mask = mask.unsqueeze(1)  # expand mask
         return mask
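
The same change is applied to CLIP, CLIPSeg, and GroupViT, since each of their text transformers carries its own copy of this helper. As a quick sanity check, a standalone mirror of the patched helper (the free function and the example sizes below are assumptions for illustration, not code from the patch) shows the mask coming back already on the requested device, shaped (bsz, 1, seq_len, seq_len), with the minimum representable value strictly above the diagonal and zeros elsewhere:

import torch

def causal_attention_mask(bsz, seq_len, dtype, device=None):
    # mirrors the body of the patched _build_causal_attention_mask
    mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
    mask.fill_(torch.finfo(dtype).min)
    mask.triu_(1)               # keep the large negative values strictly above the diagonal
    return mask.unsqueeze(1)    # (bsz, 1, seq_len, seq_len), broadcast over attention heads

device = "cuda" if torch.cuda.is_available() else "cpu"
mask = causal_attention_mask(2, 4, torch.float16, device=device)
assert mask.device.type == device            # created on the target device, no .to() copy needed
assert mask.shape == (2, 1, 4, 4)
assert (mask[0, 0].diagonal() == 0).all()    # each token attends to itself and the past
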