Unverified Commit c62b01d0 authored by Kashif Rasul, committed by GitHub

use _make_causal_mask in clip/vit models (#23942)

use _make_causal_mask in clip models
parent e03a9cc0
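For reference, a minimal standalone sketch (illustrative only, not part of the patch) of the mask the shared Bart-style helper produces for a toy shape; the names below mirror the _make_causal_mask bodies added in the hunks that follow:

```python
import torch

# Toy reproduction of the shared helper's output for bsz=2, tgt_len=4, float32.
bsz, tgt_len = 2, 4
dtype, device = torch.float32, torch.device("cpu")

mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(tgt_len, device=device)
# Zero out positions where the key index <= the query index (lower triangle incl. diagonal).
mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
mask = mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)

print(mask.shape)   # torch.Size([2, 1, 4, 4])
print(mask[0, 0])   # 0 on and below the diagonal, torch.finfo(float32).min above it
```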
src/transformers/models/clip/modeling_clip.py
@@ -673,6 +673,24 @@ class CLIPEncoder(nn.Module):
)
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
class CLIPTextTransformer(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
@@ -711,12 +729,9 @@ class CLIPTextTransformer(nn.Module):
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = self._build_causal_attention_mask(
bsz, seq_len, hidden_states.dtype, device=hidden_states.device
)
causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -752,15 +767,6 @@ class CLIPTextTransformer(nn.Module):
attentions=encoder_outputs.attentions,
)
def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
mask.fill_(torch.finfo(dtype).min)
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
@add_start_docstrings(
"""The text model from CLIP without any head or projection on top.""",
......
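To sanity-check the swap, a small self-contained comparison of the removed per-model helper against the shared Bart-style helper (both redefined locally here for illustration; this snippet is not part of the diff):

```python
import torch

def old_build_causal_attention_mask(bsz, seq_len, dtype, device=None):
    # Removed helper: fill with the dtype's minimum, then zero everything on and below the diagonal.
    mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
    mask.fill_(torch.finfo(dtype).min)
    mask.triu_(1)
    return mask.unsqueeze(1)

def new_make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
    # Shared helper added by this commit (the CLIP text path never passes a past length).
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(tgt_len, device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
    mask = mask.to(dtype)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

bsz, seq_len = 2, 5
old = old_build_causal_attention_mask(bsz, seq_len, torch.float32, device=torch.device("cpu"))
new = new_make_causal_mask((bsz, seq_len), torch.float32, torch.device("cpu"))
assert old.shape == new.shape == (bsz, 1, seq_len, seq_len)
assert torch.equal(old, new)  # identical causal masks
```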
src/transformers/models/clipseg/modeling_clipseg.py
@@ -683,6 +683,24 @@ class CLIPSegEncoder(nn.Module):
)
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
class CLIPSegTextTransformer(nn.Module):
# Copied from transformers.models.clip.modeling_clip.CLIPTextTransformer.__init__ with CLIP->CLIPSeg
def __init__(self, config: CLIPSegTextConfig):
@@ -723,12 +741,9 @@ class CLIPSegTextTransformer(nn.Module):
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
bsz, seq_len = input_shape
# CLIPSeg's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIPSeg/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clipseg/model.py#L324
causal_attention_mask = self._build_causal_attention_mask(
bsz, seq_len, hidden_states.dtype, device=hidden_states.device
)
causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -764,15 +779,6 @@ class CLIPSegTextTransformer(nn.Module):
attentions=encoder_outputs.attentions,
)
def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
mask.fill_(torch.finfo(dtype).min)
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
class CLIPSegTextModel(CLIPSegPreTrainedModel):
config_class = CLIPSegTextConfig
......
src/transformers/models/groupvit/modeling_groupvit.py
@@ -1066,6 +1066,24 @@ class GroupViTTextEncoder(nn.Module):
)
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.clip.modeling_clip.CLIPTextTransformer with CLIPText->GroupViTText, CLIPEncoder->GroupViTTextEncoder, CLIP_TEXT->GROUPVIT_TEXT
class GroupViTTextTransformer(nn.Module):
def __init__(self, config: GroupViTTextConfig):
@@ -1105,12 +1123,9 @@ class GroupViTTextTransformer(nn.Module):
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = self._build_causal_attention_mask(
bsz, seq_len, hidden_states.dtype, device=hidden_states.device
)
causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -1146,15 +1161,6 @@ class GroupViTTextTransformer(nn.Module):
attentions=encoder_outputs.attentions,
)
def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
mask.fill_(torch.finfo(dtype).min)
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
class GroupViTTextModel(GroupViTPreTrainedModel):
config_class = GroupViTTextConfig
......
src/transformers/models/owlvit/modeling_owlvit.py
@@ -783,6 +783,24 @@ class OwlViTEncoder(nn.Module):
)
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
class OwlViTTextTransformer(nn.Module):
def __init__(self, config: OwlViTTextConfig):
super().__init__()
@@ -816,10 +834,10 @@ class OwlViTTextTransformer(nn.Module):
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
num_samples, seq_len = input_shape # num_samples = batch_size * num_max_text_queries
# num_samples, seq_len = input_shape where num_samples = batch_size * num_max_text_queries
# OWLVIT's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = self._build_causal_attention_mask(num_samples, seq_len).to(hidden_states.device)
causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
# expand attention_mask
if attention_mask is not None:
# [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len]
@@ -854,15 +872,6 @@ class OwlViTTextTransformer(nn.Module):
attentions=encoder_outputs.attentions,
)
def _build_causal_attention_mask(self, bsz, seq_len):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len)
mask.fill_(torch.tensor(float("-inf")))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
class OwlViTTextModel(OwlViTPreTrainedModel):
config_class = OwlViTTextConfig
......
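One behavioral detail specific to the OwlViT hunk above: the removed helper filled with float("-inf") in the default dtype and relied on a later .to(hidden_states.device), while the shared helper fills with torch.finfo(dtype).min directly in the hidden states' dtype and device. A small illustrative sketch of that difference in float16 (not part of the diff):

```python
import torch

dtype = torch.float16
seq_len = 3

# Old OwlViT-style mask: built in default float32 with -inf, cast afterwards.
old = torch.empty(1, seq_len, seq_len)
old.fill_(float("-inf"))
old.triu_(1)
old = old.unsqueeze(1).to(dtype)           # masked entries stay -inf in fp16

# New shared mask: built with the target dtype's finite minimum.
new = torch.full((seq_len, seq_len), torch.finfo(dtype).min)
cond = torch.arange(seq_len)
new.masked_fill_(cond < (cond + 1).view(seq_len, 1), 0)
new = new.to(dtype)[None, None, :, :]

print(old[0, 0, 0, 1].item())  # -inf
print(new[0, 0, 0, 1].item())  # -65504.0, the finite minimum of float16
```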
src/transformers/models/x_clip/modeling_x_clip.py
@@ -737,6 +737,24 @@ class XCLIPEncoder(nn.Module):
)
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
class XCLIPTextTransformer(nn.Module):
def __init__(self, config: XCLIPTextConfig):
super().__init__()
@@ -775,12 +793,9 @@ class XCLIPTextTransformer(nn.Module):
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
batch_size, seq_len = input_shape
# X_CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = self._build_causal_attention_mask(batch_size, seq_len, hidden_states.dtype).to(
hidden_states.device
)
causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
# expand attention_mask
if attention_mask is not None:
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
@@ -812,15 +827,6 @@ class XCLIPTextTransformer(nn.Module):
attentions=encoder_outputs.attentions,
)
def _build_causal_attention_mask(self, batch_size, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(batch_size, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
class XCLIPTextModel(XCLIPPreTrainedModel):
config_class = XCLIPTextConfig
......