Unverified Commit ec7f8af1 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[ConvNeXT] Fix drop_path_rate (#17280)

* Fix drop_path_rate

* Fix TF's drop path rate
parent a26ab95e
......@@ -209,8 +209,9 @@ class ConvNextEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.stages = nn.ModuleList()
drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
cur = 0
drop_path_rates = [
x.tolist() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths)).split(config.depths)
]
prev_chs = config.hidden_sizes[0]
for i in range(config.num_stages):
out_chs = config.hidden_sizes[i]
......@@ -220,10 +221,9 @@ class ConvNextEncoder(nn.Module):
out_channels=out_chs,
stride=2 if i > 0 else 1,
depth=config.depths[i],
drop_path_rates=drop_path_rates[cur],
drop_path_rates=drop_path_rates[i],
)
self.stages.append(stage)
cur += config.depths[i]
prev_chs = out_chs
def forward(
......
......@@ -235,8 +235,9 @@ class TFConvNextEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.stages = []
drop_path_rates = [x for x in tf.linspace(0.0, config.drop_path_rate, sum(config.depths))]
cur = 0
drop_path_rates = tf.linspace(0.0, config.drop_path_rate, sum(config.depths))
drop_path_rates = tf.split(drop_path_rates, config.depths)
drop_path_rates = [x.numpy().tolist() for x in drop_path_rates]
prev_chs = config.hidden_sizes[0]
for i in range(config.num_stages):
out_chs = config.hidden_sizes[i]
......@@ -246,11 +247,10 @@ class TFConvNextEncoder(tf.keras.layers.Layer):
out_channels=out_chs,
stride=2 if i > 0 else 1,
depth=config.depths[i],
drop_path_rates=drop_path_rates[cur],
drop_path_rates=drop_path_rates[i],
name=f"stages.{i}",
)
self.stages.append(stage)
cur += config.depths[i]
prev_chs = out_chs
def call(self, hidden_states, output_hidden_states=False, return_dict=True):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment