Unverified Commit 6949b893 authored by Ivan Kobzarev's avatar Ivan Kobzarev Committed by GitHub
Browse files

[android] android gradle project for ops (#2897)



* [android] android gradle project for ops

* Change CMakeLists to latest PyTorch

* Use mobilenet_v3 models for detection

Don't need to have two variants of the model anymore, but I'm not removing it for now

* Fix orientation when angle = 0

* [android][test_app] Fix YUV decoding

* Use smaller version of mobilenet model

* Divide inputs by 255 again

* [android] assets mobilenetv3
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent aa264980
local.properties
**/*.iml
.gradle
gradlew*
gradle/wrapper
.idea/*
.externalNativeBuild
build
allprojects {
buildscript {
ext {
minSdkVersion = 21
targetSdkVersion = 28
compileSdkVersion = 28
buildToolsVersion = '28.0.3'
coreVersion = "1.2.0"
extJUnitVersion = "1.1.1"
runnerVersion = "1.2.0"
rulesVersion = "1.2.0"
junitVersion = "4.12"
androidSupportAppCompatV7Version = "28.0.0"
fbjniJavaOnlyVersion = "0.0.3"
soLoaderNativeLoaderVersion = "0.8.0"
}
repositories {
google()
mavenCentral()
jcenter()
}
dependencies {
classpath 'com.android.tools.build:gradle:3.3.2'
classpath "com.jfrog.bintray.gradle:gradle-bintray-plugin:${GRADLE_BINTRAY_PLUGIN_VERSION}"
classpath "com.github.dcendents:android-maven-gradle-plugin:${ANDROID_MAVEN_GRADLE_PLUGIN_VERSION}"
classpath "org.jfrog.buildinfo:build-info-extractor-gradle:4.9.8"
}
}
repositories {
google()
jcenter()
}
}
ext.deps = [
jsr305: 'com.google.code.findbugs:jsr305:3.0.1',
]
ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64
VERSION_NAME=0.0.1-SNAPSHOT
GROUP=org.pytorch
MAVEN_GROUP=org.pytorch
POM_URL=https://github.com/pytorch/vision/
POM_SCM_URL=https://github.com/pytorch/vision.git
POM_SCM_CONNECTION=scm:git:https://github.com/pytorch/vision
POM_SCM_DEV_CONNECTION=scm:git:git@github.com:pytorch/vision.git
POM_LICENSE_NAME=BSD 3-Clause
POM_LICENSE_URL=https://github.com/pytorch/vision/blob/master/LICENSE
POM_ISSUES_URL=https://github.com/pytorch/vision/issues
POM_LICENSE_DIST=repo
POM_DEVELOPER_ID=pytorch
POM_DEVELOPER_NAME=pytorch
syncWithMavenCentral=true
GRADLE_BINTRAY_PLUGIN_VERSION=1.8.0
GRADLE_VERSIONS_PLUGIN_VERSION=0.15.0
ANDROID_MAVEN_GRADLE_PLUGIN_VERSION=2.1
# Gradle internals
android.useAndroidX=true
android.enableJetifier=true
testAppAllVariantsEnabled=false
org.gradle.jvmargs=-Xmx4096m
apply plugin: 'com.github.dcendents.android-maven'
version = VERSION_NAME
group = GROUP
project.archivesBaseName = POM_ARTIFACT_ID
install {
repositories.mavenInstaller {
pom.project {
name POM_NAME
artifactId POM_ARTIFACT_ID
packaging POM_PACKAGING
description POM_DESCRIPTION
url projectUrl
scm {
url scmUrl
connection scmConnection
developerConnection scmDeveloperConnection
}
licenses {
license {
name = POM_LICENSE_NAME
url = POM_LICENSE_URL
distribution = POM_LICENSE_DIST
}
}
developers {
developer {
id developerId
name developerName
}
}
}
}
}
import java.nio.file.Files
import java.nio.file.Paths
import java.io.FileOutputStream
import java.util.zip.ZipFile
// Android tasks for Javadoc and sources.jar generation
afterEvaluate { project ->
if (POM_PACKAGING == 'aar') {
task androidJavadoc(type: Javadoc, dependsOn: assembleDebug) {
source += files(android.sourceSets.main.java.srcDirs)
failOnError false
// This task will try to compile *everything* it finds in the above directory and
// will choke on text files it doesn't understand.
exclude '**/BUCK'
exclude '**/*.md'
}
task androidJavadocJar(type: Jar, dependsOn: androidJavadoc) {
classifier = 'javadoc'
from androidJavadoc.destinationDir
}
task androidSourcesJar(type: Jar) {
classifier = 'sources'
from android.sourceSets.main.java.srcDirs
}
android.libraryVariants.all { variant ->
def name = variant.name.capitalize()
task "jar${name}"(type: Jar, dependsOn: variant.javaCompileProvider) {
from variant.javaCompileProvider.get().destinationDir
}
androidJavadoc.doFirst {
classpath += files(android.bootClasspath)
classpath += files(variant.javaCompileProvider.get().classpath.files)
// This is generated by `assembleDebug` and holds the JARs generated by the APT.
classpath += fileTree(dir: "$buildDir/intermediates/bundles/debug/", include: '**/*.jar')
// Process AAR dependencies
def aarDependencies = classpath.filter { it.name.endsWith('.aar') }
classpath -= aarDependencies
aarDependencies.each { aar ->
// Extract classes.jar from the AAR dependency, and add it to the javadoc classpath
def outputPath = "$buildDir/tmp/aarJar/${aar.name.replace('.aar', '.jar')}"
classpath += files(outputPath)
// Use a task so the actual extraction only happens before the javadoc task is run
dependsOn task(name: "extract ${aar.name}").doLast {
extractEntry(aar, 'classes.jar', outputPath)
}
}
}
}
artifacts.add('archives', androidJavadocJar)
artifacts.add('archives', androidSourcesJar)
}
if (POM_PACKAGING == 'jar') {
task javadocJar(type: Jar, dependsOn: javadoc) {
classifier = 'javadoc'
from javadoc.destinationDir
}
task sourcesJar(type: Jar, dependsOn: classes) {
classifier = 'sources'
from sourceSets.main.allSource
}
artifacts.add('archives', javadocJar)
artifacts.add('archives', sourcesJar)
}
}
// Utility method to extract only one entry in a zip file
private def extractEntry(archive, entryPath, outputPath) {
if (!archive.exists()) {
throw new GradleException("archive $archive not found")
}
def zip = new ZipFile(archive)
zip.entries().each {
if (it.name == entryPath) {
def path = Paths.get(outputPath)
if (!Files.exists(path)) {
Files.createDirectories(path.getParent())
Files.copy(zip.getInputStream(it), path)
}
}
}
zip.close()
}
apply plugin: 'com.jfrog.bintray'
def getBintrayUsername() {
return project.hasProperty('bintrayUsername') ? property('bintrayUsername') : System.getenv('BINTRAY_USERNAME')
}
def getBintrayApiKey() {
return project.hasProperty('bintrayApiKey') ? property('bintrayApiKey') : System.getenv('BINTRAY_API_KEY')
}
def getBintrayGpgPassword() {
return project.hasProperty('bintrayGpgPassword') ? property('bintrayGpgPassword') : System.getenv('BINTRAY_GPG_PASSWORD')
}
def getMavenCentralUsername() {
return project.hasProperty('mavenCentralUsername') ? property('mavenCentralUsername') : System.getenv('MAVEN_CENTRAL_USERNAME')
}
def getMavenCentralPassword() {
return project.hasProperty('mavenCentralPassword') ? property('mavenCentralPassword') : System.getenv('MAVEN_CENTRAL_PASSWORD')
}
def shouldSyncWithMavenCentral() {
return project.hasProperty('syncWithMavenCentral') ? property('syncWithMavenCentral').toBoolean() : false
}
def dryRunOnly() {
return project.hasProperty('dryRun') ? property('dryRun').toBoolean() : false
}
bintray {
user = getBintrayUsername()
key = getBintrayApiKey()
override = false
configurations = ['archives']
pkg {
repo = bintrayRepo
userOrg = bintrayUserOrg
name = bintrayName
desc = bintrayDescription
websiteUrl = projectUrl
issueTrackerUrl = issuesUrl
vcsUrl = scmUrl
licenses = [ POM_LICENSE_NAME ]
dryRun = dryRunOnly()
override = false
publish = true
publicDownloadNumbers = true
version {
name = versionName
desc = bintrayDescription
gpg {
sign = true
passphrase = getBintrayGpgPassword()
}
mavenCentralSync {
sync = shouldSyncWithMavenCentral()
user = getMavenCentralUsername()
password = getMavenCentralPassword()
close = '1' // If set to 0, you have to manually click release
}
}
}
}
apply plugin: 'signing'
version = VERSION_NAME
group = MAVEN_GROUP
def isReleaseBuild() {
return !VERSION_NAME.contains('SNAPSHOT')
}
def getReleaseRepositoryUrl() {
return hasProperty('RELEASE_REPOSITORY_URL') ? RELEASE_REPOSITORY_URL
: "https://oss.sonatype.org/service/local/staging/deploy/maven2/"
}
def getSnapshotRepositoryUrl() {
return hasProperty('SNAPSHOT_REPOSITORY_URL') ? SNAPSHOT_REPOSITORY_URL
: "https://oss.sonatype.org/content/repositories/snapshots/"
}
def getRepositoryUsername() {
return hasProperty('SONATYPE_NEXUS_USERNAME') ? SONATYPE_NEXUS_USERNAME : ""
}
def getRepositoryPassword() {
return hasProperty('SONATYPE_NEXUS_PASSWORD') ? SONATYPE_NEXUS_PASSWORD : ""
}
def getHttpProxyHost() {
return project.properties['systemProp.http.proxyHost']
}
def getHttpProxyPort() {
return project.properties['systemProp.http.proxyPort']
}
def needProxy() {
return (getHttpProxyHost() != null) && (getHttpProxyPort() != null)
}
afterEvaluate { project ->
uploadArchives {
repositories {
mavenDeployer {
beforeDeployment { MavenDeployment deployment -> signing.signPom(deployment) }
pom.groupId = MAVEN_GROUP
pom.artifactId = POM_ARTIFACT_ID
pom.version = VERSION_NAME
repository(url: getReleaseRepositoryUrl()) {
authentication(userName: getRepositoryUsername(), password: getRepositoryPassword())
if (needProxy()) {
proxy(host: getHttpProxyHost(), port: getHttpProxyPort() as Integer, type: 'http')
}
}
snapshotRepository(url: getSnapshotRepositoryUrl()) {
authentication(userName: getRepositoryUsername(), password: getRepositoryPassword())
if (needProxy()) {
proxy(host: getHttpProxyHost(), port: getHttpProxyPort() as Integer, type: 'http')
}
}
pom.project {
name POM_NAME
packaging POM_PACKAGING
description POM_DESCRIPTION
url POM_URL
scm {
url POM_SCM_URL
connection POM_SCM_CONNECTION
developerConnection POM_SCM_DEV_CONNECTION
}
licenses {
license {
name POM_LICENSE_NAME
url POM_LICENSE_URL
distribution POM_LICENSE_DIST
}
}
developers {
developer {
id POM_DEVELOPER_ID
name POM_DEVELOPER_NAME
}
}
}
}
}
}
signing {
required { isReleaseBuild() && gradle.taskGraph.hasTask('uploadArchives') }
sign configurations.archives
}
}
apply from: rootProject.file('gradle_scripts/android_tasks.gradle')
apply from: rootProject.file('gradle_scripts/release_bintray.gradle')
apply from: rootProject.file('gradle_scripts/gradle_maven_push.gradle')
ext {
bintrayRepo = 'maven'
bintrayUserOrg = 'pytorch'
bintrayName = "${GROUP}:${POM_ARTIFACT_ID}"
bintrayDescription = POM_DESCRIPTION
projectUrl = POM_URL
issuesUrl = POM_ISSUES_URL
scmUrl = POM_SCM_URL
scmConnection = POM_SCM_CONNECTION
scmDeveloperConnection = POM_SCM_DEV_CONNECTION
publishedGroupId = GROUP
libraryName = 'torchvision'
artifact = 'torchvision'
developerId = POM_DEVELOPER_ID
developerName = POM_DEVELOPER_NAME
versionName = VERSION_NAME
projectLicenses = {
license = {
name = POM_LICENSE_NAME
url = POM_LICENSE_URL
distribution = POM_LICENSE_DIST
}
}
}
apply from: rootProject.file('gradle_scripts/android_maven_install.gradle')
apply from: rootProject.file('gradle_scripts/bintray.gradle')
cmake_minimum_required(VERSION 3.4.1)
set(TARGET torchvision_ops)
project(${TARGET} CXX)
set(CMAKE_CXX_STANDARD 14)
string(APPEND CMAKE_CXX_FLAGS " -DMOBILE")
set(build_DIR ${CMAKE_SOURCE_DIR}/build)
set(root_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
file(GLOB VISION_SRCS
../../torchvision/csrc/ops/cpu/*.h
../../torchvision/csrc/ops/cpu/*.cpp
../../torchvision/csrc/ops/*.h
../../torchvision/csrc/ops/*.cpp)
add_library(${TARGET} SHARED
${VISION_SRCS}
)
file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
file(GLOB PYTORCH_INCLUDE_DIRS_CSRC "${build_DIR}/pytorch_android*.aar/headers/torch/csrc/api/include")
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")
target_compile_options(${TARGET} PRIVATE
-fexceptions
)
set(BUILD_SUBDIR ${ANDROID_ABI})
find_library(PYTORCH_LIBRARY pytorch_jni
PATHS ${PYTORCH_LINK_DIRS}
NO_CMAKE_FIND_ROOT_PATH)
find_library(FBJNI_LIBRARY fbjni
PATHS ${PYTORCH_LINK_DIRS}
NO_CMAKE_FIND_ROOT_PATH)
target_include_directories(${TARGET} PRIVATE
${PYTORCH_INCLUDE_DIRS}
${PYTORCH_INCLUDE_DIRS_CSRC}
)
target_link_libraries(${TARGET} PRIVATE
${PYTORCH_LIBRARY}
${FBJNI_LIBRARY}
)
apply plugin: 'com.android.library'
apply plugin: 'maven'
repositories {
jcenter()
maven {
url "https://oss.sonatype.org/content/repositories/snapshots"
}
flatDir {
dirs 'aars'
}
}
android {
configurations {
extractForNativeBuild
}
compileSdkVersion rootProject.compileSdkVersion
buildToolsVersion rootProject.buildToolsVersion
defaultConfig {
minSdkVersion rootProject.minSdkVersion
targetSdkVersion rootProject.targetSdkVersion
versionCode 0
versionName "0.1"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
ndk {
abiFilters ABI_FILTERS.split(",")
}
}
buildTypes {
debug {
minifyEnabled false
debuggable true
}
release {
minifyEnabled false
}
}
externalNativeBuild {
cmake {
path "CMakeLists.txt"
}
}
useLibrary 'android.test.runner'
useLibrary 'android.test.base'
useLibrary 'android.test.mock'
}
dependencies {
implementation 'com.android.support:appcompat-v7:' + rootProject.androidSupportAppCompatV7Version
implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
extractForNativeBuild 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
// For testing: deps on local aar files
//implementation(name: 'pytorch_android-release', ext: 'aar')
//extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar')
//implementation 'com.facebook.fbjni:fbjni-java-only:0.0.3'
}
task extractAARForNativeBuild {
doLast {
configurations.extractForNativeBuild.files.each {
def file = it.absoluteFile
copy {
from zipTree(file)
into "$buildDir/$file.name"
include "headers/**"
include "jni/**"
}
}
}
}
tasks.whenTaskAdded { task ->
if (task.name.contains('externalNativeBuild')) {
task.dependsOn(extractAARForNativeBuild)
}
}
apply from: rootProject.file('gradle_scripts/release.gradle')
task sourcesJar(type: Jar) {
from android.sourceSets.main.java.srcDirs
classifier = 'sources'
}
artifacts.add('archives', sourcesJar)
POM_NAME=torchvision ops
POM_DESCRIPTION=torchvision ops
POM_ARTIFACT_ID=torchvision_ops
POM_PACKAGING=aar
<manifest package="org.pytorch.torchvision.ops" />
include ':ops', ':test_app'
project(':ops').projectDir = file('ops')
project(':test_app').projectDir = file('test_app/app')
apply plugin: 'com.android.application'
repositories {
jcenter()
maven {
url "https://oss.sonatype.org/content/repositories/snapshots"
}
flatDir {
dirs 'aars'
}
}
android {
configurations {
extractForNativeBuild
}
compileOptions {
sourceCompatibility 1.8
targetCompatibility 1.8
}
compileSdkVersion rootProject.compileSdkVersion
buildToolsVersion rootProject.buildToolsVersion
defaultConfig {
applicationId "org.pytorch.testapp"
minSdkVersion rootProject.minSdkVersion
targetSdkVersion rootProject.targetSdkVersion
versionCode 1
versionName "1.0"
ndk {
abiFilters ABI_FILTERS.split(",")
}
externalNativeBuild {
cmake {
abiFilters ABI_FILTERS.split(",")
arguments "-DANDROID_STL=c++_shared"
}
}
buildConfigField("String", "MODULE_ASSET_NAME", "\"frcnn_mnetv3.pt\"")
buildConfigField("String", "LOGCAT_TAG", "@string/app_name")
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{3, 96, 96}")
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
}
buildTypes {
debug {
minifyEnabled false
debuggable true
}
release {
minifyEnabled false
}
}
flavorDimensions "model", "activity", "build"
productFlavors {
frcnnMnetv3 {
dimension "model"
applicationIdSuffix ".frcnnMnetv3"
buildConfigField("String", "MODULE_ASSET_NAME", "\"frcnn_mnetv3.pt\"")
addManifestPlaceholders([APP_NAME: "TV_FRCNN_MNETV3"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-frcnn-mnetv3\"")
}
camera {
dimension "activity"
addManifestPlaceholders([APP_NAME: "TV_CAMERA_FRCNN"])
addManifestPlaceholders([MAIN_ACTIVITY: "org.pytorch.testapp.CameraActivity"])
}
base {
dimension "activity"
}
aar {
dimension "build"
}
nightly {
dimension "build"
}
local {
dimension "build"
}
}
packagingOptions {
doNotStrip '**.so'
}
// Filtering for CI
if (!testAppAllVariantsEnabled.toBoolean()) {
variantFilter { variant ->
def names = variant.flavors*.name
if (names.contains("aar")) {
setIgnore(true)
}
}
}
}
tasks.all { task ->
// Disable externalNativeBuild for all but nativeBuild variant
if (task.name.startsWith('externalNativeBuild')
&& !task.name.contains('NativeBuild')) {
task.enabled = false
}
}
dependencies {
implementation 'com.android.support:appcompat-v7:28.0.0'
implementation 'com.facebook.soloader:nativeloader:0.8.0'
localImplementation project(':ops')
implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
implementation 'org.pytorch:pytorch_android_torchvision:1.8.0-SNAPSHOT'
implementation 'org.pytorch:torchvision_ops:0.0.1-SNAPSHOT'
aarImplementation(name: 'pytorch_android-release', ext: 'aar')
aarImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar')
def camerax_version = "1.0.0-alpha05"
implementation "androidx.camera:camera-core:$camerax_version"
implementation "androidx.camera:camera-camera2:$camerax_version"
implementation 'com.google.android.material:material:1.0.0-beta01'
}
task extractAARForNativeBuild {
doLast {
configurations.extractForNativeBuild.files.each {
def file = it.absoluteFile
copy {
from zipTree(file)
into "$buildDir/$file.name"
include "headers/**"
include "jni/**"
}
}
}
}
tasks.whenTaskAdded { task ->
if (task.name.contains('externalNativeBuild')) {
task.dependsOn(extractAARForNativeBuild)
}
}
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.pytorch.testapp">
<application
android:allowBackup="true"
android:label="${APP_NAME}"
android:supportsRtl="true"
android:theme="@style/AppTheme">
<activity android:name="${MAIN_ACTIVITY}">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
<uses-permission android:name="android.permission.CAMERA" />
</manifest>
package org.pytorch.testapp;
class BBox {
public final float score;
public final float x0;
public final float y0;
public final float x1;
public final float y1;
public BBox(float score, float x0, float y0, float x1, float y1) {
this.score = score;
this.x0 = x0;
this.y0 = y0;
this.x1 = x1;
this.y1 = y1;
}
@Override
public String toString() {
return String.format("Box{score=%f x0=%f y0=%f x1=%f y1=%f", score, x0, y0, x1, y1);
}
}
package org.pytorch.testapp;
import android.Manifest;
import android.content.Context;
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.Rect;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.SystemClock;
import android.util.DisplayMetrics;
import android.util.Log;
import android.util.Size;
import android.view.TextureView;
import android.view.ViewStub;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;
import androidx.annotation.Nullable;
import androidx.annotation.UiThread;
import androidx.annotation.WorkerThread;
import androidx.appcompat.app.AppCompatActivity;
import androidx.camera.core.CameraX;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageAnalysisConfig;
import androidx.camera.core.ImageProxy;
import androidx.camera.core.Preview;
import androidx.camera.core.PreviewConfig;
import androidx.core.app.ActivityCompat;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import org.pytorch.IValue;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;
public class CameraActivity extends AppCompatActivity {
private static final float BBOX_SCORE_DRAW_THRESHOLD = 0.5f;
private static final String TAG = BuildConfig.LOGCAT_TAG;
private static final int TEXT_TRIM_SIZE = 4096;
private static final int RGB_MAX_CHANNEL_VALUE = 262143;
private static final int REQUEST_CODE_CAMERA_PERMISSION = 200;
private static final String[] PERMISSIONS = {Manifest.permission.CAMERA};
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("pytorch_jni");
NativeLoader.loadLibrary("torchvision_ops");
}
private Bitmap mInputTensorBitmap;
private Bitmap mBitmap;
private Canvas mCanvas;
private long mLastAnalysisResultTime;
protected HandlerThread mBackgroundThread;
protected Handler mBackgroundHandler;
protected Handler mUIHandler;
private TextView mTextView;
private ImageView mCameraOverlay;
private StringBuilder mTextViewStringBuilder = new StringBuilder();
private Paint mBboxPaint;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_camera);
mTextView = findViewById(R.id.text);
mCameraOverlay = findViewById(R.id.camera_overlay);
mUIHandler = new Handler(getMainLooper());
startBackgroundThread();
if (ActivityCompat.checkSelfPermission(this, Manifest.permission.CAMERA)
!= PackageManager.PERMISSION_GRANTED) {
ActivityCompat.requestPermissions(this, PERMISSIONS, REQUEST_CODE_CAMERA_PERMISSION);
} else {
setupCameraX();
}
mBboxPaint = new Paint();
mBboxPaint.setAntiAlias(true);
mBboxPaint.setDither(true);
mBboxPaint.setColor(Color.GREEN);
}
@Override
protected void onPostCreate(@Nullable Bundle savedInstanceState) {
super.onPostCreate(savedInstanceState);
startBackgroundThread();
}
protected void startBackgroundThread() {
mBackgroundThread = new HandlerThread("ModuleActivity");
mBackgroundThread.start();
mBackgroundHandler = new Handler(mBackgroundThread.getLooper());
}
@Override
protected void onDestroy() {
stopBackgroundThread();
super.onDestroy();
}
protected void stopBackgroundThread() {
mBackgroundThread.quitSafely();
try {
mBackgroundThread.join();
mBackgroundThread = null;
mBackgroundHandler = null;
} catch (InterruptedException e) {
Log.e(TAG, "Error on stopping background thread", e);
}
}
@Override
public void onRequestPermissionsResult(
int requestCode, String[] permissions, int[] grantResults) {
if (requestCode == REQUEST_CODE_CAMERA_PERMISSION) {
if (grantResults[0] == PackageManager.PERMISSION_DENIED) {
Toast.makeText(
this,
"You can't use image classification example without granting CAMERA permission",
Toast.LENGTH_LONG)
.show();
finish();
} else {
setupCameraX();
}
}
}
private void setupCameraX() {
final TextureView textureView =
((ViewStub) findViewById(R.id.camera_texture_view_stub))
.inflate()
.findViewById(R.id.texture_view);
final PreviewConfig previewConfig = new PreviewConfig.Builder().build();
final Preview preview = new Preview(previewConfig);
preview.setOnPreviewOutputUpdateListener(
new Preview.OnPreviewOutputUpdateListener() {
@Override
public void onUpdated(Preview.PreviewOutput output) {
textureView.setSurfaceTexture(output.getSurfaceTexture());
}
});
final DisplayMetrics displayMetrics = new DisplayMetrics();
getWindowManager().getDefaultDisplay().getMetrics(displayMetrics);
final ImageAnalysisConfig imageAnalysisConfig =
new ImageAnalysisConfig.Builder()
.setTargetResolution(new Size(displayMetrics.widthPixels, displayMetrics.heightPixels))
.setCallbackHandler(mBackgroundHandler)
.setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
.build();
final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);
imageAnalysis.setAnalyzer(
new ImageAnalysis.Analyzer() {
@Override
public void analyze(ImageProxy image, int rotationDegrees) {
if (SystemClock.elapsedRealtime() - mLastAnalysisResultTime < 500) {
return;
}
final Result result = CameraActivity.this.analyzeImage(image, rotationDegrees);
if (result != null) {
mLastAnalysisResultTime = SystemClock.elapsedRealtime();
CameraActivity.this.runOnUiThread(
new Runnable() {
@Override
public void run() {
CameraActivity.this.handleResult(result);
}
});
}
}
});
CameraX.bindToLifecycle(this, preview, imageAnalysis);
}
private Module mModule;
private FloatBuffer mInputTensorBuffer;
private Tensor mInputTensor;
private static int clamp0255(int x) {
if (x > 255) {
return 255;
}
return x < 0 ? 0 : x;
}
protected void fillInputTensorBuffer(
ImageProxy image,
int rotationDegrees,
FloatBuffer inputTensorBuffer) {
if (mInputTensorBitmap == null) {
final int tensorSize = Math.min(image.getWidth(), image.getHeight());
mInputTensorBitmap = Bitmap.createBitmap(tensorSize, tensorSize, Bitmap.Config.ARGB_8888);
}
ImageProxy.PlaneProxy[] planes = image.getPlanes();
ImageProxy.PlaneProxy Y = planes[0];
ImageProxy.PlaneProxy U = planes[1];
ImageProxy.PlaneProxy V = planes[2];
ByteBuffer yBuffer = Y.getBuffer();
ByteBuffer uBuffer = U.getBuffer();
ByteBuffer vBuffer = V.getBuffer();
final int imageWidth = image.getWidth();
final int imageHeight = image.getHeight();
final int tensorSize = Math.min(imageWidth, imageHeight);
int widthAfterRtn = imageWidth;
int heightAfterRtn = imageHeight;
boolean oddRotation = rotationDegrees == 90 || rotationDegrees == 270;
if (oddRotation) {
widthAfterRtn = imageHeight;
heightAfterRtn = imageWidth;
}
int minSizeAfterRtn = Math.min(heightAfterRtn, widthAfterRtn);
int cropWidthAfterRtn = minSizeAfterRtn;
int cropHeightAfterRtn = minSizeAfterRtn;
int cropWidthBeforeRtn = cropWidthAfterRtn;
int cropHeightBeforeRtn = cropHeightAfterRtn;
if (oddRotation) {
cropWidthBeforeRtn = cropHeightAfterRtn;
cropHeightBeforeRtn = cropWidthAfterRtn;
}
int offsetX = (int) ((imageWidth - cropWidthBeforeRtn) / 2.f);
int offsetY = (int) ((imageHeight - cropHeightBeforeRtn) / 2.f);
int yRowStride = Y.getRowStride();
int yPixelStride = Y.getPixelStride();
int uvRowStride = U.getRowStride();
int uvPixelStride = U.getPixelStride();
float scale = cropWidthAfterRtn / tensorSize;
int yIdx, uvIdx, yi, ui, vi;
final int channelSize = tensorSize * tensorSize;
for (int y = 0; y < tensorSize; y++) {
for (int x = 0; x < tensorSize; x++) {
final int centerCropX = (int) Math.floor(x * scale);
final int centerCropY = (int) Math.floor(y * scale);
int srcX = centerCropX + offsetX;
int srcY = centerCropY + offsetY;
if (rotationDegrees == 90) {
srcX = offsetX + centerCropY;
srcY = offsetY + (minSizeAfterRtn - 1) - centerCropX;
} else if (rotationDegrees == 180) {
srcX = offsetX + (minSizeAfterRtn - 1) - centerCropX;
srcY = offsetY + (minSizeAfterRtn - 1) - centerCropY;
} else if (rotationDegrees == 270) {
srcX = offsetX + (minSizeAfterRtn - 1) - centerCropY;
srcY = offsetY + centerCropX;
}
yIdx = srcY * yRowStride + srcX * yPixelStride;
uvIdx = (srcY >> 1) * uvRowStride + (srcX >> 1) * uvPixelStride;
yi = yBuffer.get(yIdx) & 0xff;
ui = uBuffer.get(uvIdx) & 0xff;
vi = vBuffer.get(uvIdx) & 0xff;
yi = (yi - 16) < 0 ? 0 : (yi - 16);
ui -= 128;
vi -= 128;
int a0 = 1192 * yi;
int ri = (a0 + 1634 * vi);
int gi = (a0 - 833 * vi - 400 * ui);
int bi = (a0 + 2066 * ui);
ri = ri > RGB_MAX_CHANNEL_VALUE ? RGB_MAX_CHANNEL_VALUE : (ri < 0 ? 0 : ri);
gi = gi > RGB_MAX_CHANNEL_VALUE ? RGB_MAX_CHANNEL_VALUE : (gi < 0 ? 0 : gi);
bi = bi > RGB_MAX_CHANNEL_VALUE ? RGB_MAX_CHANNEL_VALUE : (bi < 0 ? 0 : bi);
final int color = 0xff000000 | ((ri << 6) & 0xff0000) | ((gi >> 2) & 0xff00) | ((bi >> 10) & 0xff);
mInputTensorBitmap.setPixel(x, y, color);
inputTensorBuffer.put(0 * channelSize + y * tensorSize + x, clamp0255(ri >> 10) / 255.f);
inputTensorBuffer.put(1 * channelSize + y * tensorSize + x, clamp0255(gi >> 10) / 255.f);
inputTensorBuffer.put(2 * channelSize + y * tensorSize + x, clamp0255(bi >> 10) / 255.f);
}
}
}
public static String assetFilePath(Context context, String assetName) {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
Log.e(TAG, "Error process asset " + assetName + " to file path");
}
return null;
}
@WorkerThread
@Nullable
protected Result analyzeImage(ImageProxy image, int rotationDegrees) {
Log.i(TAG, String.format("analyzeImage(%s, %d)", image, rotationDegrees));
final int tensorSize = Math.min(image.getWidth(), image.getHeight());
if (mModule == null) {
Log.i(TAG, "Loading module from asset '" + BuildConfig.MODULE_ASSET_NAME + "'");
mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * tensorSize * tensorSize);
mInputTensor = Tensor.fromBlob(mInputTensorBuffer, new long[]{3, tensorSize, tensorSize});
final String modelFileAbsoluteFilePath =
new File(assetFilePath(this, BuildConfig.MODULE_ASSET_NAME)).getAbsolutePath();
mModule = Module.load(modelFileAbsoluteFilePath);
}
final long startTime = SystemClock.elapsedRealtime();
fillInputTensorBuffer(image, rotationDegrees, mInputTensorBuffer);
final long moduleForwardStartTime = SystemClock.elapsedRealtime();
final IValue outputTuple = mModule.forward(IValue.listFrom(mInputTensor));
final IValue out1 = outputTuple.toTuple()[1];
final Map<String, IValue> map = out1.toList()[0].toDictStringKey();
float[] boxesData = new float[]{};
float[] scoresData = new float[]{};
final List<BBox> bboxes = new ArrayList<>();
if (map.containsKey("boxes")) {
final Tensor boxesTensor = map.get("boxes").toTensor();
final Tensor scoresTensor = map.get("scores").toTensor();
boxesData = boxesTensor.getDataAsFloatArray();
scoresData = scoresTensor.getDataAsFloatArray();
final int n = scoresData.length;
for (int i = 0; i < n; i++) {
final BBox bbox = new BBox(
scoresData[i],
boxesData[4 * i + 0],
boxesData[4 * i + 1],
boxesData[4 * i + 2],
boxesData[4 * i + 3]
);
android.util.Log.i(TAG, String.format("Forward result %d: %s", i, bbox));
bboxes.add(bbox);
}
} else {
android.util.Log.i(TAG, "Forward result empty");
}
final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;
final long analysisDuration = SystemClock.elapsedRealtime() - startTime;
return new Result(tensorSize, bboxes, moduleForwardDuration, analysisDuration);
}
@UiThread
protected void handleResult(Result result) {
final int W = mCameraOverlay.getMeasuredWidth();
final int H = mCameraOverlay.getMeasuredHeight();
final int size = Math.min(W, H);
final int offsetX = (W - size) / 2;
final int offsetY = (H - size) / 2;
float scaleX = (float) size / result.tensorSize;
float scaleY = (float) size / result.tensorSize;
if (mBitmap == null) {
mBitmap = Bitmap.createBitmap(W, H, Bitmap.Config.ARGB_8888);
mCanvas = new Canvas(mBitmap);
}
mCanvas.drawBitmap(
mInputTensorBitmap,
new Rect(0, 0, result.tensorSize, result.tensorSize),
new Rect(offsetX, offsetY, offsetX + size, offsetY + size),
null
);
for (final BBox bbox : result.bboxes) {
if (bbox.score < BBOX_SCORE_DRAW_THRESHOLD) {
continue;
}
float c_x0 = offsetX + scaleX * bbox.x0;
float c_y0 = offsetY + scaleY * bbox.y0;
float c_x1 = offsetX + scaleX * bbox.x1;
float c_y1 = offsetY + scaleY * bbox.y1;
mCanvas.drawLine(c_x0, c_y0, c_x1, c_y0, mBboxPaint);
mCanvas.drawLine(c_x1, c_y0, c_x1, c_y1, mBboxPaint);
mCanvas.drawLine(c_x1, c_y1, c_x0, c_y1, mBboxPaint);
mCanvas.drawLine(c_x0, c_y1, c_x0, c_y0, mBboxPaint);
mCanvas.drawText(String.format("%.2f", bbox.score), c_x0, c_y0, mBboxPaint);
}
mCameraOverlay.setImageBitmap(mBitmap);
String message = String.format("forwardDuration:%d", result.moduleForwardDuration);
Log.i(TAG, message);
mTextViewStringBuilder.insert(0, '\n').insert(0, message);
if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) {
mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length());
}
mTextView.setText(mTextViewStringBuilder.toString());
}
}
package org.pytorch.testapp;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.SystemClock;
import android.util.Log;
import android.widget.TextView;
import androidx.annotation.Nullable;
import androidx.annotation.UiThread;
import androidx.annotation.WorkerThread;
import androidx.appcompat.app.AppCompatActivity;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;
import java.nio.FloatBuffer;
import java.util.Map;
public class MainActivity extends AppCompatActivity {
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("pytorch_jni");
NativeLoader.loadLibrary("torchvision_ops");
}
private static final String TAG = BuildConfig.LOGCAT_TAG;
private static final int TEXT_TRIM_SIZE = 4096;
private TextView mTextView;
protected HandlerThread mBackgroundThread;
protected Handler mBackgroundHandler;
private Module mModule;
private FloatBuffer mInputTensorBuffer;
private Tensor mInputTensor;
private StringBuilder mTextViewStringBuilder = new StringBuilder();
private final Runnable mModuleForwardRunnable =
new Runnable() {
@Override
public void run() {
final Result result = doModuleForward();
runOnUiThread(
() -> {
handleResult(result);
if (mBackgroundHandler != null) {
mBackgroundHandler.post(mModuleForwardRunnable);
}
});
}
};
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
mTextView = findViewById(R.id.text);
startBackgroundThread();
mBackgroundHandler.post(mModuleForwardRunnable);
}
protected void startBackgroundThread() {
mBackgroundThread = new HandlerThread(TAG + "_bg");
mBackgroundThread.start();
mBackgroundHandler = new Handler(mBackgroundThread.getLooper());
}
@Override
protected void onDestroy() {
stopBackgroundThread();
super.onDestroy();
}
protected void stopBackgroundThread() {
mBackgroundThread.quitSafely();
try {
mBackgroundThread.join();
mBackgroundThread = null;
mBackgroundHandler = null;
} catch (InterruptedException e) {
Log.e(TAG, "Error stopping background thread", e);
}
}
@WorkerThread
@Nullable
protected Result doModuleForward() {
if (mModule == null) {
final long[] shape = BuildConfig.INPUT_TENSOR_SHAPE;
long numElements = 1;
for (int i = 0; i < shape.length; i++) {
numElements *= shape[i];
}
mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements);
mInputTensor = Tensor.fromBlob(mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE);
PyTorchAndroid.setNumThreads(1);
mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
}
final long startTime = SystemClock.elapsedRealtime();
final long moduleForwardStartTime = SystemClock.elapsedRealtime();
final IValue outputTuple = mModule.forward(IValue.listFrom(mInputTensor));
final IValue[] outputArray = outputTuple.toTuple();
final IValue out0 = outputArray[0];
final Map<String, IValue> map = out0.toDictStringKey();
if (map.containsKey("boxes")) {
final Tensor boxes = map.get("boxes").toTensor();
final Tensor scores = map.get("scores").toTensor();
final float[] boxesData = boxes.getDataAsFloatArray();
final float[] scoresData = scores.getDataAsFloatArray();
final int n = scoresData.length;
for (int i = 0; i < n; i++) {
android.util.Log.i(TAG,
String.format("Forward result %d: score %f box:(%f, %f, %f, %f)",
scoresData[i],
boxesData[4 * i + 0],
boxesData[4 * i + 1],
boxesData[4 * i + 2],
boxesData[4 * i + 3]));
}
} else {
android.util.Log.i(TAG, "Forward result empty");
}
final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;
final long analysisDuration = SystemClock.elapsedRealtime() - startTime;
return new Result(new float[]{}, moduleForwardDuration, analysisDuration);
}
static class Result {
private final float[] scores;
private final long totalDuration;
private final long moduleForwardDuration;
public Result(float[] scores, long moduleForwardDuration, long totalDuration) {
this.scores = scores;
this.moduleForwardDuration = moduleForwardDuration;
this.totalDuration = totalDuration;
}
}
@UiThread
protected void handleResult(Result result) {
String message = String.format("forwardDuration:%d", result.moduleForwardDuration);
mTextViewStringBuilder.insert(0, '\n').insert(0, message);
if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) {
mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length());
}
mTextView.setText(mTextViewStringBuilder.toString());
}
}
package org.pytorch.testapp;
import java.util.List;
class Result {
public final int tensorSize;
public final List<BBox> bboxes;
public final long totalDuration;
public final long moduleForwardDuration;
public Result(int tensorSize, List<BBox> bboxes, long moduleForwardDuration, long totalDuration) {
this.tensorSize = tensorSize;
this.bboxes = bboxes;
this.moduleForwardDuration = moduleForwardDuration;
this.totalDuration = totalDuration;
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment