Unverified Commit 58434879 authored by Partho, committed by GitHub

Renamed single-letter variables to more descriptive names (#449)

* renamed variable names

q -> query
k -> key
v -> value
b -> batch
c -> channel
h -> height
w -> weight

* renamed variable names

Missed some in the initial commit

* renamed more variable names

As per code review suggestions, renamed x -> hidden_states and x_in -> residual

* fixed minor typo
parent 5adb0a7b
@@ -137,18 +137,18 @@ class SpatialTransformer(nn.Module):
         for block in self.transformer_blocks:
             block._set_attention_slice(slice_size)
 
-    def forward(self, x, context=None):
+    def forward(self, hidden_states, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
-        b, c, h, w = x.shape
-        x_in = x
-        x = self.norm(x)
-        x = self.proj_in(x)
-        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+        batch, channel, height, weight = hidden_states.shape
+        residual = hidden_states
+        hidden_states = self.norm(hidden_states)
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
-            x = block(x, context=context)
-        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
-        x = self.proj_out(x)
-        return x + x_in
+            hidden_states = block(hidden_states, context=context)
+        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = self.proj_out(hidden_states)
+        return hidden_states + residual
 
 
 class BasicTransformerBlock(nn.Module):
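For reference, the renamed forward walks the feature map from `(batch, channel, height, weight)` to a sequence layout and back. A minimal standalone sketch of that reshape round trip, using a dummy tensor rather than the module itself (sizes are made up; `weight` mirrors the name used in the diff and holds the spatial width):

```python
import torch

# Dummy sizes, purely for illustration.
batch, channel, height, weight = 2, 320, 64, 64
hidden_states = torch.randn(batch, channel, height, weight)

# (batch, channel, height, weight) -> (batch, height * weight, channel)
seq = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
assert seq.shape == (batch, height * weight, channel)

# ... the transformer blocks operate on this (batch, seq_len, channel) layout ...

# (batch, height * weight, channel) -> (batch, channel, height, weight)
restored = seq.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
assert torch.equal(restored, hidden_states)
```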
@@ -192,12 +192,12 @@ class BasicTransformerBlock(nn.Module):
         self.attn1._slice_size = slice_size
         self.attn2._slice_size = slice_size
 
-    def forward(self, x, context=None):
-        x = x.contiguous() if x.device.type == "mps" else x
-        x = self.attn1(self.norm1(x)) + x
-        x = self.attn2(self.norm2(x), context=context) + x
-        x = self.ff(self.norm3(x)) + x
-        return x
+    def forward(self, hidden_states, context=None):
+        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
+        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
+        hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
+        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+        return hidden_states
 
 
 class CrossAttention(nn.Module):
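The renamed block keeps the pre-norm residual pattern: self-attention, cross-attention against the optional context, then feed-forward, each added back onto hidden_states. A rough usage sketch; the import path and constructor arguments are assumptions for illustration, since only forward is shown in this hunk:

```python
import torch
# Import path as used by the file this diff touches; treat it as an assumption
# if you are on a different diffusers release.
from diffusers.models.attention import BasicTransformerBlock

# Hypothetical dimensions; the constructor signature is not part of this hunk.
block = BasicTransformerBlock(dim=320, n_heads=8, d_head=40, context_dim=768)

hidden_states = torch.randn(2, 4096, 320)  # (batch, seq_len, dim)
context = torch.randn(2, 77, 768)          # e.g. text-encoder hidden states

# Each sub-layer is added back onto hidden_states as a residual,
# so the output shape matches the input shape.
out = block(hidden_states, context=context)
assert out.shape == hidden_states.shape
```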
@@ -247,22 +247,22 @@ class CrossAttention(nn.Module):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    def forward(self, x, context=None, mask=None):
-        batch_size, sequence_length, dim = x.shape
+    def forward(self, hidden_states, context=None, mask=None):
+        batch_size, sequence_length, dim = hidden_states.shape
 
-        q = self.to_q(x)
-        context = context if context is not None else x
-        k = self.to_k(context)
-        v = self.to_v(context)
+        query = self.to_q(hidden_states)
+        context = context if context is not None else hidden_states
+        key = self.to_k(context)
+        value = self.to_v(context)
 
-        q = self.reshape_heads_to_batch_dim(q)
-        k = self.reshape_heads_to_batch_dim(k)
-        v = self.reshape_heads_to_batch_dim(v)
+        query = self.reshape_heads_to_batch_dim(query)
+        key = self.reshape_heads_to_batch_dim(key)
+        value = self.reshape_heads_to_batch_dim(value)
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
-        hidden_states = self._attention(q, k, v, sequence_length, dim)
+        hidden_states = self._attention(query, key, value, sequence_length, dim)
 
         return self.to_out(hidden_states)
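This hunk calls self._attention(query, key, value, sequence_length, dim) without showing its body. As a rough mental model only, not the exact code in this file (which also supports sliced attention via _slice_size), the un-sliced path amounts to standard scaled dot-product attention over the head-flattened tensors:

```python
import torch

def scaled_dot_product_attention(query, key, value, scale):
    # query/key/value: (batch * heads, seq_len, dim_per_head), the layout
    # produced by reshape_heads_to_batch_dim in the hunk above.
    attention_scores = torch.matmul(query, key.transpose(-1, -2)) * scale
    attention_probs = attention_scores.softmax(dim=-1)
    # (batch * heads, query_len, dim_per_head)
    return torch.matmul(attention_probs, value)

# Hypothetical shapes, purely for illustration.
batch_heads, q_len, kv_len, dim_per_head = 16, 64, 77, 40
query = torch.randn(batch_heads, q_len, dim_per_head)
key = torch.randn(batch_heads, kv_len, dim_per_head)
value = torch.randn(batch_heads, kv_len, dim_per_head)

out = scaled_dot_product_attention(query, key, value, scale=dim_per_head ** -0.5)
assert out.shape == (batch_heads, q_len, dim_per_head)
```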
@@ -308,8 +308,8 @@ class FeedForward(nn.Module):
 
         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
-    def forward(self, x):
-        return self.net(x)
+    def forward(self, hidden_states):
+        return self.net(hidden_states)
 
 
 # feedforward
@@ -326,6 +326,6 @@ class GEGLU(nn.Module):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)
 
-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
+    def forward(self, hidden_states):
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+        return hidden_states * F.gelu(gate)
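GEGLU projects to twice the output width, splits the result into a value half and a gate half, and multiplies the value by GELU of the gate. A tiny standalone check of that gating, with made-up dimensions:

```python
import torch
import torch.nn.functional as F
from torch import nn

dim_in, dim_out = 8, 16
proj = nn.Linear(dim_in, dim_out * 2)

hidden_states = torch.randn(2, 4, dim_in)

# Same computation as GEGLU.forward after the rename: one projection,
# split in half along the last dim, gate passed through GELU.
hidden_states, gate = proj(hidden_states).chunk(2, dim=-1)
out = hidden_states * F.gelu(gate)
assert out.shape == (2, 4, dim_out)
```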