Unverified Commit 1bd4c9e9 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

remvoe one line as requested by gc team (#3077)

remvoe one line
parent eb2ef316
......@@ -340,11 +340,10 @@ def main():
return examples
if jax.process_index() == 0:
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
train_dataset = dataset["train"].with_transform(preprocess_train)
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment