"ts/vscode:/vscode.git/clone" did not exist on "a16e570ddb6e976a78fc1234a0697019fd836ed1"
Unverified Commit b9b70b0e authored by Lars Mennen, committed by GitHub
Browse files

Patch for FlanT5-XXL 8bit support (#20760)

* Workaround for #20287: FlanT5-XXL 8bit support

* Make fix-copies

* revert unrelated change

* Don't apply to LongT5 and Switch Transformers
parent fe9152f6
......@@ -281,7 +281,6 @@ class LongT5DenseActDense(nn.Module):
return hidden_states
# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->LongT5
class LongT5DenseGatedActDense(nn.Module):
def __init__(self, config: LongT5Config):
super().__init__()
......
......@@ -278,7 +278,7 @@ class SwitchTransformersDenseActDense(nn.Module):
return hidden_states
# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->SwitchTransformers
# Copied from transformers.models.longt5.modeling_longt5.LongT5DenseGatedActDense with LongT5->SwitchTransformers
class SwitchTransformersDenseGatedActDense(nn.Module):
def __init__(self, config: SwitchTransformersConfig):
super().__init__()
......
......@@ -308,6 +308,12 @@ class T5DenseGatedActDense(nn.Module):
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
# See https://github.com/huggingface/transformers/issues/20287
if hidden_states.dtype != self.wo.weight.dtype:
hidden_states = hidden_states.to(self.wo.weight.dtype)
hidden_states = self.wo(hidden_states)
return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment