Unverified Commit 23a04e08 authored by lsz05's avatar lsz05 Committed by GitHub
Browse files

[Fix] Support cls pooling in ModernBertPooler (#20067)


Signed-off-by: default avatarshengzhe.li <shengzhe.li@sbintuitions.co.jp>
parent 02c97d9a
...@@ -258,6 +258,7 @@ class ModernBertPooler(nn.Module): ...@@ -258,6 +258,7 @@ class ModernBertPooler(nn.Module):
super().__init__() super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size, self.dense = nn.Linear(config.hidden_size, config.hidden_size,
config.classifier_bias) config.classifier_bias)
self.pooling_type = config.classifier_pooling
self.act = nn.GELU() self.act = nn.GELU()
self.norm = nn.LayerNorm(config.hidden_size, self.norm = nn.LayerNorm(config.hidden_size,
eps=config.norm_eps, eps=config.norm_eps,
...@@ -265,7 +266,13 @@ class ModernBertPooler(nn.Module): ...@@ -265,7 +266,13 @@ class ModernBertPooler(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
pooled_output = hidden_states pooled_output = hidden_states
if self.pooling_type == "mean":
pooled_output = pooled_output.mean(dim=0, keepdim=False) pooled_output = pooled_output.mean(dim=0, keepdim=False)
elif self.pooling_type == "cls":
pooled_output = pooled_output[0, :]
else:
raise ValueError("Pooling type should be either `cls` or `mean`, "
f"but got {self.pooling_type}")
pooled_output = self.norm(self.act(self.dense(pooled_output))) pooled_output = self.norm(self.act(self.dense(pooled_output)))
return pooled_output return pooled_output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment