Unverified Commit b9b70b0e authored by Lars Mennen's avatar Lars Mennen Committed by GitHub
Browse files

Patch for FlanT5-XXL 8bit support (#20760)

* Workaround for #20287: FlanT5-XXL 8bit support

* Make fix-copies

* revert unrelated change

* Don't apply to LongT5 and Switch Transformers
parent fe9152f6
...@@ -281,7 +281,6 @@ class LongT5DenseActDense(nn.Module): ...@@ -281,7 +281,6 @@ class LongT5DenseActDense(nn.Module):
return hidden_states return hidden_states
# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->LongT5
class LongT5DenseGatedActDense(nn.Module): class LongT5DenseGatedActDense(nn.Module):
def __init__(self, config: LongT5Config): def __init__(self, config: LongT5Config):
super().__init__() super().__init__()
......
...@@ -278,7 +278,7 @@ class SwitchTransformersDenseActDense(nn.Module): ...@@ -278,7 +278,7 @@ class SwitchTransformersDenseActDense(nn.Module):
return hidden_states return hidden_states
# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->SwitchTransformers # Copied from transformers.models.longt5.modeling_longt5.LongT5DenseGatedActDense with LongT5->SwitchTransformers
class SwitchTransformersDenseGatedActDense(nn.Module): class SwitchTransformersDenseGatedActDense(nn.Module):
def __init__(self, config: SwitchTransformersConfig): def __init__(self, config: SwitchTransformersConfig):
super().__init__() super().__init__()
......
...@@ -308,6 +308,12 @@ class T5DenseGatedActDense(nn.Module): ...@@ -308,6 +308,12 @@ class T5DenseGatedActDense(nn.Module):
hidden_linear = self.wi_1(hidden_states) hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
# See https://github.com/huggingface/transformers/issues/20287
if hidden_states.dtype != self.wo.weight.dtype:
hidden_states = hidden_states.to(self.wo.weight.dtype)
hidden_states = self.wo(hidden_states) hidden_states = self.wo(hidden_states)
return hidden_states return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment