"docs/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "40022d075a0413a49e7802ec2bbc7563413a1cb5"
Unverified commit 41fa672d authored by Younes Belkada, committed by GitHub

Enable `requires_grad` on input embedding to train on top of frozen layers (#21598)

* v1

* make fixup

* add more methods
parent 8c502662
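
The mechanism is an ordinary PyTorch forward hook on the embedding module. As a minimal standalone sketch (plain PyTorch, not code from this commit; the toy modules are invented for illustration): once every parameter is frozen, the forward graph contains no tensor that requires grad, so backward() raises a RuntimeError; forcing requires_grad on the embedding output re-attaches autograd from that point on.

import torch
import torch.nn as nn

# Toy frozen model: an embedding feeding a linear layer.
embedding = nn.Embedding(10, 4)
linear = nn.Linear(4, 4)
for p in (*embedding.parameters(), *linear.parameters()):
    p.requires_grad = False

# Same hook body the commit adds: force grad tracking on the embedding output.
def make_inputs_require_grads(module, input, output):
    output.requires_grad_(True)

hook = embedding.register_forward_hook(make_inputs_require_grads)

out = linear(embedding(torch.tensor([1, 2, 3])))
out.sum().backward()  # succeeds; without the hook this raises a RuntimeError
hook.remove()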
@@ -1148,6 +1148,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             return False
         return True
 
+    def enable_input_require_grads(self):
+        """
+        Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
+        the model weights fixed.
+        """
+
+        def make_inputs_require_grads(module, input, output):
+            output.requires_grad_(True)
+
+        self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+
+    def disable_input_require_grads(self):
+        """
+        Removes the `_require_grads_hook`.
+        """
+        self._require_grads_hook.remove()
+
     def get_input_embeddings(self) -> nn.Module:
         """
         Returns the model's input embeddings.
...
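
Typical usage pairs these new methods with a fully frozen base model, e.g. for adapter-style fine-tuning. A hedged usage sketch (the gpt2 checkpoint and the dummy batch are placeholders, not part of this commit):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint

# Freeze every base weight; only externally attached adapter weights would train.
for param in model.parameters():
    param.requires_grad = False

# Re-enable grad flow from the (frozen) input embeddings so that backward()
# reaches whatever trainable modules sit on top of them.
model.enable_input_require_grads()

input_ids = torch.tensor([[1, 2, 3]])
loss = model(input_ids, labels=input_ids).loss
loss.backward()  # gradients now flow despite the frozen embedding matrix

# Detach the forward hook once it is no longer needed.
model.disable_input_require_grads()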